adv-optm 2.3.dev1.tar.gz → 2.3.dev3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/PKG-INFO +1 -1
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/__init__.py +1 -1
- adv_optm-2.3.dev3/adv_optm/stiefel_optm/Stiefel_LoRA.py +231 -0
- adv_optm-2.3.dev3/adv_optm/stiefel_optm/__init__.py +5 -0
- adv_optm-2.3.dev3/adv_optm/stiefel_optm/stiefel_util.py +149 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/SOURCES.txt +3 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/setup.py +1 -1
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/LICENSE +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/README.md +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/setup.cfg +0 -0

adv_optm-2.3.dev3/adv_optm/stiefel_optm/Stiefel_LoRA.py (new file, +231 lines)

@@ -0,0 +1,231 @@
```python
import torch

from typing import Optional

from ..util import param_update
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
from . import stiefel_util


class Stiefel_LoRA(torch.optim.Optimizer):
    """
    Implements an advanced Stiefel_LoRA algorithm.

    based on the paper:
    ""
    Under the hood it is a modified SignSGD with momentum (SignUM).

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate (default: 1e-4).
        momentum (float, optional): coefficient for computing the running
            average of the gradients (default: 0.9).
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
        cautious_wd (bool): enables Cautious Weight Decay. If True, weight decay
            is applied only to parameter coordinates where the sign of the
            parameter and the sign of the optimizer update align (default: False).
        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
            matrices to apply low-rank compression (default: False).
        stochastic_rounding (bool, optional): whether to use stochastic
            rounding for BF16 parameter updates (default: True).
        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update
            rule. This changes the EMA to an accumulator and the update numerator
            to `alpha_grad * grad + mt`, which can be more responsive, especially
            for small batch sizes (default: False).
        alpha_grad (float): mixing coefficient for the Simplified AdEMAMix update
            rule (only used when `Simplified_AdEMAMix` is True). Controls the
            weight of the current gradient. For small batch sizes, use high
            values (e.g., 10-100) to be more responsive; for large batch sizes,
            use low values (e.g., 0-1) for stability (default: 1.0).
        nnmf_factor (bool): whether to use the factorized (compressed) optimizer
            states instead of the uncompressed ones (default: False).
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        momentum: float = 0.9,
        # Decoupled/cautious weight decay
        weight_decay: float = 0.0,
        cautious_wd: bool = False,
        # Stochastic rounding for BF16
        stochastic_rounding: bool = True,
        # Simplified AdEMAMix
        alpha_grad: float = 1.0,
        Simplified_AdEMAMix: bool = False,
        # SMMF factorization
        nnmf_factor: bool = False,
        vector_reshape: bool = False,
        # torch.compile
        compiled_optimizer: bool = False,
    ):
        if not lr > 0.0:
            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
        if not 0.0 <= momentum <= 1.0:
            raise ValueError(f"momentum should be in [0.0, 1.0], but got {momentum}")
        if not weight_decay >= 0.0:
            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            cautious_wd=cautious_wd,
            vector_reshape=vector_reshape,
            alpha_grad=alpha_grad,
            Simplified_AdEMAMix=Simplified_AdEMAMix,
            nnmf_factor=nnmf_factor,
            # Needed by step_parameter, which reads this flag back from the group.
            compiled_optimizer=compiled_optimizer,
        )
        self.stochastic_rounding = stochastic_rounding
        self._init_lr = lr

        super().__init__(params, defaults)

        if self.stochastic_rounding:
            # For deterministic stochastic rounding, we need to seed the generator
            # for each device used by the parameters.
            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
            for device in devices:
                param_update.set_seed(device)

        # Initialize compiled function
        self._compiled_step_parameter = None
        if compiled_optimizer:
            self.compile(fullgraph=True)

    @property
    def supports_fused_back_pass(self) -> bool:
        return True

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return True

    @property
    def supports_flat_params(self) -> bool:
        return False

    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
        """Performs a single optimization step on a single parameter."""
        if p.grad is None:
            return

        grad = p.grad
        state = self.state[p]

        # State initialization
        if group["momentum"] > 0 and len(state) == 0:
            state['factored'] = (
                group['nnmf_factor'] and
                not (len(p.shape) == 1 and not group['vector_reshape'])
            )

            dtype = torch.float32 if state['factored'] else p.dtype

            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']
                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                # Signs are bit-packed: one uint8 holds 8 sign bits.
                packed_d2 = (d2 + 7) // 8
                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
            else:
                state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

        lr = group["lr"]

        random_int_tensor = None

        if group.get('compiled_optimizer', False):
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                # Pre-generate random tensor for stochastic rounding if needed.
                random_int_tensor = param_update._get_random_int_for_sr(p)
            lr = torch.as_tensor(lr, dtype=torch.float64)
            step_param_fn = self._compiled_step_parameter
        else:
            step_param_fn = self._step_parameter

        step_param_fn(p, grad, state, group, lr, random_int_tensor)

    def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
        # .get() because the state stays empty when momentum == 0.
        if grad.dtype != torch.float32 and state.get('factored'):
            grad = grad.float()

        is_stiefel, is_stiefel_euclidean, is_scale = stiefel_util.set_flags_AB(p)

        momentum = group["momentum"]
        Simplified_AdEMAMix = group["Simplified_AdEMAMix"]
        alpha_grad = group["alpha_grad"]

        if state.get('factored'):
            # Factored path
            d1, d2 = state['effective_shape']
            grad_reshaped = grad.view(d1, d2)

            if momentum > 0:
                # Reconstruct momentum m_{t-1}
                exp_avg = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
                exp_avg.mul_(momentum).add_(grad_reshaped)

                if Simplified_AdEMAMix:
                    raw_update = exp_avg + (grad_reshaped * alpha_grad)
                else:
                    raw_update = exp_avg.clone()

                # Compress new momentum m_t and store factors
                state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(exp_avg, signed=True)
            else:
                raw_update = grad_reshaped.clone()

            raw_update = raw_update.view(p.shape)

        else:
            # Fallback to standard SignSGD logic
            if momentum > 0:
                exp_avg = state["exp_avg"]
                exp_avg.mul_(momentum).add_(grad)

                if Simplified_AdEMAMix:
                    raw_update = exp_avg + (grad * alpha_grad)
                else:
                    raw_update = exp_avg.clone()
            else:
                raw_update = grad.clone()

        if is_stiefel:
            update = raw_update
        else:
            update = raw_update.sign_()
            # The RMS of a sign() tensor is 1.
            # We scale it to 0.2 to match the approximate RMS of AdamW.
            update = update.mul_(lr * 0.2)

        stiefel_util.apply_stiefel_update(
            self, p, group, update, lr,
            random_int_tensor=random_int_tensor,
            is_B=is_stiefel,
            is_A=is_stiefel_euclidean,
            is_scale=is_scale,
        )

    def compile(self, *args, **kwargs):
        self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for i, p in enumerate(group["params"]):
                if p.grad is not None:
                    self.step_parameter(p, group, i)

        return loss
```

adv_optm-2.3.dev3/adv_optm/stiefel_optm/stiefel_util.py (new file, +149 lines)

@@ -0,0 +1,149 @@
```python
import torch

from torch import Tensor

from typing import Dict, Any


def tangent_proj(p, update, lr):
    """
    [Stiefel-LoRA] Step 1: Tangent-space projection.
    Formula: update = update - p @ sym(p.T @ update)
    """
    pt_u = torch.matmul(p.t(), update)
    sym_pt_u = 0.5 * (pt_u + pt_u.t())
    # Project the update onto the tangent space
    update.sub_(torch.matmul(p, sym_pt_u))
    del pt_u, sym_pt_u
    rms_rescaling(update, lr)
    return update


def qr_retraction(p):
    """[Stiefel-LoRA] Step 2: Manifold retraction (QR decomposition)."""
    Q, R = torch.linalg.qr(p)
    # Fix the sign ambiguity of QR so the retraction is unique.
    d = R.diagonal().sign_()
    Q *= d
    return p.copy_(Q)


def rms_rescaling(update, lr):
    """Rescales `update` so its RMS (root mean square) equals lr * 0.2."""
    # RMS = sqrt(mean(update**2))
    # Frobenius norm = sqrt(sum(update**2))
    # Relationship: norm = RMS * sqrt(numel)
    numel = update.numel()
    norm = torch.linalg.vector_norm(update).clamp_min_(1e-12)
    target_norm = lr * (numel ** 0.5) * 0.2
    return update.mul_(target_norm / norm)


def set_flags_AB(p):
    """
    Identify whether a parameter is LoRA B (Stiefel), LoRA A (Euclidean),
    or a scale parameter. Returns (is_B, is_A, is_scale).
    """
    if getattr(p, '_is_dora_scale', False):
        return False, False, True
    if getattr(p, '_is_lora_B', False):
        return True, False, False
    if getattr(p, '_is_lora_A', False):
        return False, True, False

    # Fallback heuristic (handles 4D Conv2d layers properly)
    dim0 = p.shape[0]
    dim1 = p.shape[1] if p.ndim > 1 else 1

    is_scale = p.ndim == 1 or (p.ndim == 2 and (dim0 == 1 or dim1 == 1))
    if is_scale:
        return False, False, True

    B = dim0 > dim1  # tall matrix -> LoRA B
    A = dim0 < dim1  # wide matrix -> LoRA A
    return B, A, False


def apply_stiefel_update(
    self,
    p: Tensor,
    group: Dict[str, Any],
    update: Tensor,
    lr: float | Tensor,
    wd: float | None = None,
    random_int_tensor: Tensor | None = None,
    is_B: bool | None = None,
    is_A: bool | None = None,
    is_scale: bool | None = False,
) -> None:
    from ..util.param_update import _copy_stochastic_core_, copy_stochastic_
    wd = group["weight_decay"] if wd is None else wd
    cautious = group.get('cautious_wd', False)

    if is_B or p.ndim == 1 or is_scale:
        # Disable weight decay for the ortho matrix B or DoRA norm
        wd = 0

    if is_A:
        # For matrix A, normalize weight decay by rank to make it invariant
        wd = wd / p.shape[0]

    scaled_wd = wd * (lr / self._init_lr)

    # Compute full update in float32 if using bfloat16 with stochastic rounding
    if p.dtype == torch.bfloat16 and self.stochastic_rounding:
        p_fp32 = p.float()
        update_fp32 = update.float()

        if is_B:
            update_fp32 = tangent_proj(p_fp32, update_fp32, lr)
            p_fp32.add_(-update_fp32)
            p_fp32 = qr_retraction(p_fp32)
            if random_int_tensor is not None:
                _copy_stochastic_core_(p, p_fp32, random_int_tensor)
                del random_int_tensor
            else:
                copy_stochastic_(p, p_fp32)
            del p_fp32, update_fp32
            return

        # Apply weight decay if needed
        if wd != 0:
            if cautious:
                # Cautious Weight Decay
                mask = (update_fp32 * p_fp32 >= 0).float()
                p_fp32.addcmul_(p_fp32, mask, value=-scaled_wd)
                del mask
            else:
                # Standard decoupled weight decay
                p_fp32.add_(p_fp32, alpha=-scaled_wd)

        # Apply main update
        p_fp32.add_(-update_fp32)

        # Single stochastic rounding at the end
        if random_int_tensor is not None:
            # Compiled path: use the pre-computed random tensor
            _copy_stochastic_core_(p, p_fp32, random_int_tensor)
            del random_int_tensor
        else:
            # Uncompiled path: generate randoms inside
            copy_stochastic_(p, p_fp32)
        del p_fp32, update_fp32

    else:
        # Standard path for non-bfloat16 or without stochastic rounding
        if is_B:
            update = tangent_proj(p, update, lr)
            p.add_(-update)
            p = qr_retraction(p)
            return

        if wd != 0:
            if cautious:
                # Cautious Weight Decay
                mask = (update * p >= 0).to(p.dtype)
                p.addcmul_(p, mask, value=-scaled_wd)
                del mask
            else:
                # Standard decoupled weight decay
                p.add_(p, alpha=-scaled_wd)

        # Apply main update
        p.add_(-update)

    del update
```
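
The two geometry helpers follow the standard project-then-retract recipe for the Stiefel manifold: `tangent_proj` removes the component of the update that would violate `p.T @ p = I` to first order, and `qr_retraction` snaps the stepped point back onto the manifold exactly. A standalone sketch (plain torch, mirroring the helpers above rather than importing them) verifying both properties:

```python
import torch

torch.manual_seed(0)

# An orthonormal starting point (what LoRA B is kept as) and a raw update.
p, _ = torch.linalg.qr(torch.randn(64, 4))
g = torch.randn(64, 4)

# Tangent-space projection: g <- g - p @ sym(p.T @ g)
pt_g = p.T @ g
g = g - p @ (0.5 * (pt_g + pt_g.T))
# On the tangent space at p, p.T @ g must be skew-symmetric:
print(torch.allclose(p.T @ g + g.T @ p, torch.zeros(4, 4), atol=1e-5))  # True

# Take a step off the manifold, then retract with sign-fixed QR.
Q, R = torch.linalg.qr(p - 1e-2 * g)
p = Q * R.diagonal().sign()  # sign fix makes the QR factorization unique

print(torch.allclose(p.T @ p, torch.eye(4), atol=1e-5))  # True: back on the manifold
```

The projection only enforces orthonormality to first order, which is why the retraction is still needed after every step.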

{adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/SOURCES.txt

```diff
@@ -17,6 +17,9 @@ adv_optm/optim/Prodigy_adv.py
 adv_optm/optim/SignSGD_adv.py
 adv_optm/optim/Simplified_AdEMAMix.py
 adv_optm/optim/__init__.py
+adv_optm/stiefel_optm/Stiefel_LoRA.py
+adv_optm/stiefel_optm/__init__.py
+adv_optm/stiefel_optm/stiefel_util.py
 adv_optm/util/Kourkoutas.py
 adv_optm/util/Muon_AuxAdam.py
 adv_optm/util/Muon_util.py
```