adv-optm 2.4.dev12__tar.gz → 2.4.dev13__tar.gz
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/PKG-INFO +1 -1
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py +15 -5
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/SinkSGD_adv.py +19 -10
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/sinkhorn.py +8 -20
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/setup.py +1 -1
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/LICENSE +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/README.md +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/setup.cfg +0 -0

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py

@@ -232,12 +232,13 @@ class AdamW_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-
-
-
-
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
 
         # State Initialization

@@ -303,6 +304,15 @@ class AdamW_adv(torch.optim.Optimizer):
 
         _init_fisher_wd_scaler(group, state, p)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         beta1, beta2 = group['betas']
 
         current_step = state['step']
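
The new step_parameter entry point applies the update to one tensor at a time, which enables per-parameter stepping, for example from post-accumulation gradient hooks so each gradient can be freed as soon as it is consumed. A minimal usage sketch, assuming an AdamW_adv instance; the hook wiring below is illustrative and not part of the package:

    import torch

    def attach_fused_step(opt) -> None:
        # Register a hook per parameter so each tensor is stepped as soon as
        # its gradient has been accumulated, then the gradient is released.
        def make_hook(group, i):
            def hook(param):
                opt.step_parameter(param, group, i)  # update only this tensor
                param.grad = None                    # cap peak gradient memory
            return hook

        for group in opt.param_groups:
            for i, p in enumerate(group['params']):
                if p.requires_grad:
                    p.register_post_accumulate_grad_hook(make_hook(group, i))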

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/SinkSGD_adv.py

@@ -116,6 +116,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
         for device in devices:
             param_update.set_seed(device)
 
+        self.init_step()
+
         self._compiled_step_parameter = None
         if compiled_optimizer:
             self.compile(fullgraph=True)
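
Calling init_step() in the constructor materializes all per-parameter state up front, so when compiled_optimizer is set the compiled step sees a fixed set of state tensors rather than state created lazily on the first step. A hedged usage sketch; the constructor arguments and the top-level import are assumptions, not taken from this diff:

    import torch
    import torch.nn as nn
    from adv_optm import SinkSGD_adv  # assumed top-level export

    model = nn.Linear(16, 16)
    # Because __init__ now calls init_step(), state for every parameter
    # exists before the step function is compiled, so the first step does
    # not mutate the state dict inside the compiled region.
    opt = SinkSGD_adv(model.parameters(), lr=1e-3, compiled_optimizer=True)  # assumed args

    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    opt.step()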

@@ -136,14 +138,14 @@ class SinkSGD_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-
-
-
-
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
-
         # State Initialization
         if 'step' not in state:
             state['step'] = 0

@@ -180,6 +182,15 @@ class SinkSGD_adv(torch.optim.Optimizer):
 
         _init_anchor(p, state, group)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         step_size = group['lr']
 
         random_int_tensor = None

@@ -219,7 +230,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
 
         if momentum != 0:
             buf = _reconstruct_state((state['mu_b_nmf'], state['mv_b_nmf'], state['sign'], d2), signed=True)
-            buf.mul_(momentum).add_(grad_reshaped, alpha=1 - momentum)
+            buf.lerp_(grad_reshaped, 1 - momentum)
 
             # Factorize updated buffer
             state['mu_b_nmf'], state['mv_b_nmf'], state['sign'] = _factorize_state(buf.clone(), signed=True)

@@ -239,9 +250,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
 
         if momentum != 0:
             buf = get_state(state, 'momentum_buffer', actual_precision)
-
-            buf.mul_(momentum).add_(grad, alpha=1 - momentum)
-
+            buf.lerp_(grad, 1 - momentum)
 
             set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
 
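
Both momentum hunks replace the two-kernel mul_/add_ EMA with a single lerp_, which computes the same convex combination buf <- momentum * buf + (1 - momentum) * grad. A standalone check of the identity:

    import torch

    momentum = 0.9
    grad = torch.randn(128)

    buf_old = torch.randn(128)
    buf_new = buf_old.clone()

    buf_old.mul_(momentum).add_(grad, alpha=1 - momentum)  # previous form
    buf_new.lerp_(grad, 1 - momentum)                      # buf + (grad - buf) * (1 - momentum)

    assert torch.allclose(buf_old, buf_new, atol=1e-6)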

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/sinkhorn.py

@@ -36,13 +36,10 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
 
 
     # Precompute scaling factors.
-    scale_first = m if scale_cond else n
-    scale_second = n if scale_cond else m
+    scale_first = math.sqrt(m if scale_cond else n)
+    scale_second = math.sqrt(n if scale_cond else m)
 
     if ortho_project:
-        # Pre-compute squares for the mathematical trick in ortho_normed
-        target_norm_sq_first = scale_first ** 2
-        target_norm_sq_second = scale_second ** 2
         param_2d = p.float().view(p.shape[0], -1)
         p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
         p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
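
Taking the square root changes each normalization target from norm m (or n) to norm sqrt(m) (or sqrt(n)), so the entries of a normalized row or column have RMS of about 1 instead of about sqrt(dim). A small sketch of the effect along one axis; which of m and n applies to rows versus columns depends on scale_cond, so the axis choice below is only illustrative:

    import math
    import torch

    m, n = 8, 16
    update = torch.randn(m, n)

    # Scale each row to norm sqrt(n): per-entry RMS becomes ~1, so the
    # update magnitude no longer grows with the matrix dimension.
    norms = update.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
    update = update * (math.sqrt(n) / norms)

    print(update.pow(2).mean(dim=1).sqrt())  # ~1.0 for every row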

@@ -53,23 +50,17 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
     norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
     update_2d.mul_(scale_first / norm1)
     if ortho_project:
-        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first, target_norm_sq_first)
+        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
 
     # Second normalization step
     norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
     update_2d.mul_(scale_second / norm2)
     if ortho_project:
-        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second, target_norm_sq_second)
-
-    # Final step
-    norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
-    update_2d.mul_(scale_first / norm1)
-    if ortho_project:
-        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first, target_norm_sq_first)
+        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
 
     return update_2d.view(original_shape).to(original_dtype)
 
-def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm, target_norm_sq):
+def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
     """
     Projects the update to be orthogonal to p along 'dim' and restores the original norm.
     """

@@ -80,10 +71,7 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm, target_norm_sq):
     # In-place subtraction: update_2d = update_2d - (proj * p_2d)
     update_2d.addcmul_(proj, p_2d, value=-1.0)
 
-    # Magnitude Preservation
-
-
-    g_orth_norm_sq = (target_norm_sq - proj_norm_sq).clamp_min_(1e-30)
-
-    scale_factor = target_norm / torch.sqrt(g_orth_norm_sq)
+    # Magnitude Preservation
+    g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+    scale_factor = target_norm / g_orth_norm
     return update_2d.mul_(scale_factor)
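
The magnitude-preservation step previously inferred the orthogonal component's norm from the Pythagorean identity norm(g_orth)^2 = target^2 - norm(proj)^2, which is exact only while the incoming update's norm equals the target; the rewritten code measures the norm directly after the subtraction. A self-contained single-row sketch of the post-diff behaviour; the helper name is illustrative:

    import torch

    def ortho_rescale(p_row: torch.Tensor, g_row: torch.Tensor, target_norm: float) -> torch.Tensor:
        # Remove the component of g_row along p_row, then rescale to target_norm.
        p_norm_sq = p_row.dot(p_row).clamp_min(1e-30)
        proj = g_row.dot(p_row) / p_norm_sq
        g_orth = g_row - proj * p_row             # orthogonal to p_row
        norm = g_orth.norm().clamp_min(1e-12)     # measured, not inferred
        return g_orth * (target_norm / norm)

    p, g = torch.randn(32), torch.randn(32)
    out = ortho_rescale(p, g, target_norm=1.0)
    print(out.dot(p).abs().item())  # ~0: orthogonal to the parameter row
    print(out.norm().item())        # 1.0: target norm restored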