adv-optm 1.0.0__tar.gz → 1.0.1__tar.gz
This diff shows the content of publicly released versions of the package as published to a supported public registry; it is provided for informational purposes only and reflects the changes between those versions.
Potentially problematic release: this version of adv-optm might be problematic.
- {adv_optm-1.0.0 → adv_optm-1.0.1}/PKG-INFO +1 -1
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/__init__.py +1 -1
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/AdamW_adv.py +19 -19
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Adopt_adv.py +24 -24
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Lion_Prodigy_adv.py +8 -8
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Lion_adv.py +8 -8
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Prodigy_adv.py +475 -475
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Simplified_AdEMAMix.py +3 -3
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.0.0 → adv_optm-1.0.1}/setup.py +1 -1
- {adv_optm-1.0.0 → adv_optm-1.0.1}/LICENSE +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/README.md +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.0.0 → adv_optm-1.0.1}/setup.cfg +0 -0
adv_optm/optim/AdamW_adv.py

@@ -30,8 +30,8 @@ class AdamW_adv(torch.optim.Optimizer):
 stochastic_rounding (bool): whether to use stochastic
 rounding for BF16 parameter updates (default: True).
 use_atan2 (bool): whether to use the atan2 update rule. (default: False)
-
-
+grams_moment (bool): whether to use Grams-style updates. (default: False)
+cautious_mask (bool): whether to use cautious masking to align the gradient's
 direction with the first moment's. (default: False)
 use_orthograd (bool): whether to use OrthoGrad. (default: False)
 use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
@@ -54,7 +54,7 @@ class AdamW_adv(torch.optim.Optimizer):
 as it gradually introduces the stabilizing slow momentum term. During
 the warmup, `alpha` ramps from 0 to its target value. If `None`,
 the scheduler is disabled. (default: None)
-
+nnmf_factor (bool): whether to use the factorization or disable it to use
 the uncompressed optimizer. (default: False)
 """

@@ -69,14 +69,14 @@ class AdamW_adv(torch.optim.Optimizer):
 vector_reshape: bool = True,
 stochastic_rounding: bool = True,
 use_atan2: bool = False,
-
-
+cautious_mask: bool = False,
+grams_moment: bool = False,
 use_orthograd: bool = False,
 use_AdEMAMix: bool = False,
 beta3_ema: float = 0.9999,
 alpha: float = 5.0,
 t_alpha: int | None = None,
-
+nnmf_factor: bool = False,
 ):
 if not (lr >= 0.0):
 raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
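For orientation, a minimal construction sketch using the 1.0.1 keyword names shown above (`cautious_mask`, `grams_moment`, `nnmf_factor`). The `from adv_optm import AdamW_adv` import path and the other argument values are assumptions for illustration, not taken from the release.

```python
import torch
from adv_optm import AdamW_adv  # import path assumed from the package layout above

model = torch.nn.Linear(128, 64)
optimizer = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    cautious_mask=True,   # renamed keyword in 1.0.1; incompatible with grams_moment
    grams_moment=False,
    nnmf_factor=False,    # keep the uncompressed optimizer state
)

model(torch.randn(4, 128)).sum().backward()
optimizer.step()
optimizer.zero_grad()
```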
@@ -86,9 +86,9 @@ class AdamW_adv(torch.optim.Optimizer):
 raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
 if not (weight_decay >= 0.0):
 raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
-if
-print("Warning:
-
+if cautious_mask and grams_moment:
+print("Warning: cautious is incompatible with grams, Disabling cautious.")
+cautious_mask = False

 defaults = {
 "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -97,10 +97,10 @@ class AdamW_adv(torch.optim.Optimizer):
 "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
 }
 self.stochastic_rounding = stochastic_rounding
-self.
-self.
+self.cautious_mask = cautious_mask
+self.grams_moment = grams_moment
 self.use_AdEMAMix = use_AdEMAMix
-self.factored =
+self.factored = nnmf_factor
 super().__init__(params, defaults)

 @property
@@ -151,7 +151,7 @@ class AdamW_adv(torch.optim.Optimizer):
 if beta1 > 0:
 state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
-if not self.
+if not self.grams_moment:
 packed_d2 = (d2 + 7) // 8
 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
 if self.use_AdEMAMix:
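The `sign` buffer stores one sign bit per momentum entry, eight per byte, which is where `packed_d2 = (d2 + 7) // 8` comes from. Below is a standalone sketch of that kind of bit-packing; the package's real helpers (`_pack_bools`/`_unpack_bools` in `adv_optm/util/One_Bit_Boolean.py`) may use a different bit order.

```python
import torch

def pack_bools(b: torch.Tensor) -> torch.Tensor:
    """Pack a (d1, d2) bool matrix into (d1, ceil(d2 / 8)) uint8, LSB-first per byte."""
    d1, d2 = b.shape
    packed_d2 = (d2 + 7) // 8
    padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=b.device)
    padded[:, :d2] = b
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=b.device)
    return (padded.view(d1, packed_d2, 8).long() * weights).sum(dim=-1).to(torch.uint8)

def unpack_bools(packed: torch.Tensor, original_m: int) -> torch.Tensor:
    """Invert pack_bools, trimming the byte padding back to original_m columns."""
    d1, packed_d2 = packed.shape
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=packed.device)
    bits = (packed.unsqueeze(-1).long() & weights) != 0
    return bits.view(d1, packed_d2 * 8)[:, :original_m]

signs = torch.randn(16, 10) > 0
assert torch.equal(unpack_bools(pack_bools(signs), original_m=10), signs)
```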
@@ -192,16 +192,16 @@ class AdamW_adv(torch.optim.Optimizer):
 # Reconstruct momentum from previous step's factors
 if beta1 > 0:
 mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
-if not self.
+if not self.grams_moment:
 unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
 torch.where(unpacked_sign, mt, -mt, out=mt)
 del unpacked_sign
 # Update momentum in full-size
 grad_reshaped = grad.view(d1, d2)
 mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
-if self.
+if self.grams_moment:
 mt.copy_(grad_reshaped.sign() * mt.abs())
-elif self.
+elif self.cautious_mask:
 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 mt.mul_(mask)
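The two new branches are small element-wise transformations of the first moment: Grams replaces its sign with the gradient's sign, and the cautious mask zeroes entries that disagree with the gradient and renormalizes the rest. A self-contained sketch of the same arithmetic on plain tensors (function names here are illustrative, not part of the package API):

```python
import torch

def grams_update(m: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Grams-style: keep the momentum's magnitude, take the sign from the current gradient.
    return grad.sign() * m.abs()

def cautious_update(m: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Cautious masking: drop momentum entries whose sign disagrees with the gradient,
    # then rescale so the surviving entries keep roughly the same average weight.
    mask = (m * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    return m * mask

m, g = torch.randn(4, 4), torch.randn(4, 4)
print(grams_update(m, g))
print(cautious_update(m, g))
```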
@@ -240,7 +240,7 @@ class AdamW_adv(torch.optim.Optimizer):

 # Compress updated moments and store new factors
 if beta1 > 0:
-if not self.
+if not self.grams_moment:
 state['sign'] = _pack_bools(mt > 0)
 _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
 del mt
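The compression path stores the momentum as packed sign bits plus a non-negative factorization of its magnitude. A rough sketch of that round trip, assuming `_nnmf`/`_unnmf` work with a rank-1 factor pair `(u, v)` of lengths `d1` and `d2` (consistent with the `mu_m_nmf`/`mv_m_nmf` state shapes); the actual routine in `adv_optm/util/NNMF.py` may be more elaborate.

```python
import torch

def rank1_nnmf(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """One plausible rank-1 non-negative factorization of a matrix a >= 0."""
    u = a.mean(dim=1)                        # (d1,)
    denom = (u * u).sum().clamp(min=1e-12)
    v = (a.t() @ u) / denom                  # least-squares fit of v given u, (d2,)
    return u, v.clamp(min=0)

def rank1_reconstruct(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return torch.outer(u, v)

mt = torch.randn(32, 16)
signs = mt > 0                               # stored packed as uint8 in the real optimizer
u, v = rank1_nnmf(mt.abs())                  # factor the magnitude only
recon = rank1_reconstruct(u, v)
approx = torch.where(signs, recon, -recon)   # reapply the signs on reconstruction
print((approx - mt).abs().mean())            # compression error of the rank-1 approximation
```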
@@ -257,9 +257,9 @@ class AdamW_adv(torch.optim.Optimizer):
 if beta1 > 0:
 exp_avg = state['exp_avg']
 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-if self.
+if self.grams_moment:
 exp_avg = grad.sign() * exp_avg.abs()
-elif self.
+elif self.cautious_mask:
 mask = (exp_avg * grad > 0).to(grad.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 exp_avg.mul_(mask)
adv_optm/optim/Adopt_adv.py

@@ -36,9 +36,9 @@ class Adopt_adv(torch.optim.Optimizer):
 rounding for BF16 parameter updates (default: True).
 use_atan2 (bool): whether to use an atan2-based normalization, which can
 improve stability by removing the need for `eps`. (default: False)
-
+cautious_mask (bool): whether to use cautious masking to align the gradient's
 direction with the first moment's. (default: False)
-
+grams_moment (bool): whether to combine the gradient's direction with the
 first moment's magnitude (default: False).
 use_orthograd (bool): whether to use OrthoGrad. (default: False)
 use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
@@ -65,14 +65,14 @@ class Adopt_adv(torch.optim.Optimizer):
 Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
 This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
 more responsive, especially for small batch sizes. Enabling this will
-automatically disable `use_AdEMAMix`, `
+automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
 and `use_atan2`. (default: False)
 alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
 (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
 current gradient. For small batch sizes, use high values (e.g., 10-100) to be
 more responsive. For large batch sizes, use low values (e.g., 0-1) for
 stability. (default: 100.0)
-
+nnmf_factor (bool): whether to use the factorization or disable it to use
 the uncompressed optimizer. (default: False)
 """

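The docstring above defines the Simplified AdEMAMix rule by its accumulator and the numerator `alpha_grad * grad + mt`. A schematic of just that arithmetic (the real `step()` also handles the second moment, clipping, and factorization):

```python
import torch

def simplified_ademamix_numerator(
    mt: torch.Tensor, grad: torch.Tensor, beta1: float, alpha_grad: float
) -> torch.Tensor:
    # Accumulator form: the gradient is added without the usual (1 - beta1) scaling,
    # matching the `alpha=1.0` branch in the Adopt_adv step() diff further below.
    mt.mul_(beta1).add_(grad)
    # Documented numerator: a responsive raw-gradient term plus the accumulator.
    return alpha_grad * grad + mt

mt = torch.zeros(8)
numerator = simplified_ademamix_numerator(mt, torch.randn(8), beta1=0.99, alpha_grad=100.0)
```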
@@ -87,8 +87,8 @@ class Adopt_adv(torch.optim.Optimizer):
 vector_reshape: bool = True,
 stochastic_rounding: bool = True,
 use_atan2: bool = False,
-
-
+cautious_mask: bool = False,
+grams_moment: bool = False,
 use_orthograd: bool = False,
 use_AdEMAMix: bool = False,
 beta3_ema: float = 0.9999,
@@ -96,7 +96,7 @@ class Adopt_adv(torch.optim.Optimizer):
 t_alpha: int | None = None,
 Simplified_AdEMAMix: bool = False,
 alpha_grad: float = 100.0,
-
+nnmf_factor: bool = False,
 ):
 if not (lr >= 0.0):
 raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
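A minimal construction sketch for the 1.0.1 `Adopt_adv` signature with the renamed keywords; the import path and argument values are illustrative assumptions.

```python
import torch
from adv_optm import Adopt_adv  # import path assumed from the package layout above

model = torch.nn.Linear(64, 64)
optimizer = Adopt_adv(
    model.parameters(),
    lr=1e-3,
    Simplified_AdEMAMix=True,  # forces cautious_mask/grams_moment/use_atan2 off, per the checks below
    alpha_grad=100.0,          # docstring suggests high values for small batch sizes
    nnmf_factor=False,         # uncompressed optimizer state
)

model(torch.randn(2, 64)).sum().backward()
optimizer.step()
optimizer.zero_grad()
```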
@@ -106,17 +106,17 @@ class Adopt_adv(torch.optim.Optimizer):
 raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
 if not (weight_decay >= 0.0):
 raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
-if
-print("Warning:
-
+if cautious_mask and grams_moment:
+print("Warning: cautious is incompatible with grams, Disabling cautious.")
+cautious_mask = False
 if betas[0] == 0.0 and Simplified_AdEMAMix:
 raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
 if use_AdEMAMix and Simplified_AdEMAMix:
 print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
-if
-print("Warning:
-if
-print("Warning:
+if grams_moment and Simplified_AdEMAMix:
+print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
+if cautious_mask and Simplified_AdEMAMix:
+print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
 if use_atan2 and Simplified_AdEMAMix:
 print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
 use_atan2 = False
@@ -129,12 +129,12 @@ class Adopt_adv(torch.optim.Optimizer):
 self.clip_lambda = clip_lambda
 self.stochastic_rounding = stochastic_rounding
 self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
-self.
-self.
+self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
+self.grams_moment = grams_moment and not Simplified_AdEMAMix
 self.use_orthograd = use_orthograd
 self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
 self.Simplified_AdEMAMix = Simplified_AdEMAMix
-self.factored =
+self.factored = nnmf_factor
 super().__init__(params, defaults)

 @property
@@ -176,7 +176,7 @@ class Adopt_adv(torch.optim.Optimizer):
 # m_0 = 0
 state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
-if not self.
+if not self.grams_moment:
 packed_d2 = (d2 + 7) // 8
 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
 if self.use_AdEMAMix:
@@ -220,7 +220,7 @@ class Adopt_adv(torch.optim.Optimizer):

 # Reconstruct m_{t-1}
 mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
-if not self.
+if not self.grams_moment:
 if state['sign'].dtype != torch.uint8:
 state['sign'] = state['sign'].to(torch.uint8)
 unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
@@ -257,9 +257,9 @@ class Adopt_adv(torch.optim.Optimizer):
 mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
 else:
 mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
-if self.
+if self.grams_moment:
 mt = grad_reshaped.sign() * mt.abs()
-elif self.
+elif self.cautious_mask:
 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 mt.mul_(mask)
@@ -284,7 +284,7 @@ class Adopt_adv(torch.optim.Optimizer):
 del grad_reshaped

 # Compress and store new factors
-if not self.
+if not self.grams_moment:
 state['sign'] = _pack_bools(mt > 0)
 _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
 del mt
@@ -322,9 +322,9 @@ class Adopt_adv(torch.optim.Optimizer):
 else:
 m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

-if self.
+if self.grams_moment:
 m = grad.sign() * m.abs()
-elif self.
+elif self.cautious_mask:
 mask = (m * grad > 0).to(grad.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 m.mul_(mask)
adv_optm/optim/Lion_Prodigy_adv.py

@@ -26,12 +26,12 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 matrices to apply low-rank compression (default: True).
 stochastic_rounding (bool, optional): whether to use stochastic
 rounding for BF16 parameter updates (default: True).
-
+cautious_mask (bool): whether to use the cautious masking technique. (default: False).
 clip_threshold (float, optional): whether to clip the gradients norm
 per-parameter as proposed in the paper `Lions and Muons: Optimization via
 Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
 (default: 0.0).
-
+nnmf_factor (bool): whether to use the factorization or use the
 uncompressed optimizer. (default: True)
 d0 (float):
 Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
@@ -61,9 +61,9 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 vector_reshape: bool = True,
 stochastic_rounding: bool = True,
 use_orthograd: bool = False,
-
+cautious_mask: bool = False,
 clip_threshold: float = 0.0,
-
+nnmf_factor: bool = True,
 # prodigy parameters
 beta3: float = None,
 d0: float = 1e-6,
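A construction sketch for `Lion_Prodigy_adv` using only the keywords visible in this diff (`cautious_mask`, `nnmf_factor`); the import path is assumed and every other argument is left at its package default.

```python
import torch
from adv_optm import Lion_Prodigy_adv  # import path assumed from the package layout above

model = torch.nn.Linear(32, 32)
optimizer = Lion_Prodigy_adv(
    model.parameters(),
    cautious_mask=True,   # renamed keyword in 1.0.1
    nnmf_factor=True,     # factorized (compressed) state, the documented default
)

model(torch.randn(2, 32)).sum().backward()
optimizer.step()
optimizer.zero_grad()
```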
@@ -92,8 +92,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 fsdp_in_use=fsdp_in_use,
 )
 self.stochastic_rounding = stochastic_rounding
-self.
-self.factored =
+self.cautious_mask = cautious_mask
+self.factored = nnmf_factor
 self.fsdp_in_use = fsdp_in_use
 super().__init__(params, defaults)
 # Global state for accumulating metrics across parameter updates within a single step.
@@ -197,7 +197,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
 signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=(1-self.beta1)).sign_()

-if self.
+if self.cautious_mask:
 mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 signed_update.mul_(mask)
@@ -224,7 +224,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 exp_avg = exp_avg.float()
 signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=(1-self.beta1)).sign_()

-if self.
+if self.cautious_mask:
 mask = (signed_update * grad > 0).to(grad.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 signed_update.mul_(mask)
adv_optm/optim/Lion_adv.py

@@ -26,12 +26,12 @@ class Lion_adv(torch.optim.Optimizer):
 matrices to apply low-rank compression (default: True).
 stochastic_rounding (bool, optional): whether to use stochastic
 rounding for BF16 parameter updates (default: True).
-
+cautious_mask (bool): whether to use the cautious masking technique. (default: False).
 clip_threshold (float, optional): whether to clip the gradients norm
 per-parameter as proposed in the paper `Lions and Muons: Optimization via
 Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
 (default: 0.0).
-
+nnmf_factor (bool): whether to use the factorization or use the
 uncompressed optimizer. (default: True)
 """

@@ -44,9 +44,9 @@ class Lion_adv(torch.optim.Optimizer):
 vector_reshape: bool = True,
 stochastic_rounding: bool = True,
 use_orthograd: bool = False,
-
+cautious_mask: bool = False,
 clip_threshold: float = 0.0,
-
+nnmf_factor: bool = True,
 ):
 if not lr > 0.0:
 raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -64,8 +64,8 @@ class Lion_adv(torch.optim.Optimizer):
 clip_threshold=clip_threshold,
 )
 self.stochastic_rounding = stochastic_rounding
-self.
-self.factored =
+self.cautious_mask = cautious_mask
+self.factored = nnmf_factor
 super().__init__(params, defaults)

 @property
@@ -140,7 +140,7 @@ class Lion_adv(torch.optim.Optimizer):
 # Compute update term c_t
 signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

-if self.
+if self.cautious_mask:
 mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 signed_update.mul_(mask)
@@ -167,7 +167,7 @@ class Lion_adv(torch.optim.Optimizer):
 exp_avg = exp_avg.float()
 signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

-if self.
+if self.cautious_mask:
 mask = (signed_update * grad > 0).to(grad.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 signed_update.mul_(mask)
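Both Lion variants apply the same cautious mask to the signed update. A standalone sketch of that update rule (illustrative names; the real `step()` also handles weight decay, clipping, and the momentum write-back):

```python
import torch

def cautious_lion_update(exp_avg: torch.Tensor, grad: torch.Tensor, beta1: float) -> torch.Tensor:
    # c_t = sign(beta1 * m_{t-1} + (1 - beta1) * g_t), as in the hunks above.
    signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1 - beta1).sign_()
    # Cautious mask: drop components whose sign disagrees with the current gradient.
    mask = (signed_update * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    return signed_update.mul_(mask)

update = cautious_lion_update(torch.randn(10), torch.randn(10), beta1=0.9)
```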