adv-optm 2.3.dev2__tar.gz → 2.4.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/PKG-INFO +1 -1
  2. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/__init__.py +1 -7
  3. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/AdaMuon_adv.py +28 -0
  4. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/AdamW_adv.py +86 -27
  5. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Adopt_adv.py +95 -33
  6. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Lion_adv.py +80 -5
  7. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Muon_adv.py +28 -0
  8. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Prodigy_adv.py +74 -25
  9. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/SignSGD_adv.py +94 -6
  10. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +62 -13
  11. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/Muon_AuxAdam.py +3 -0
  12. adv_optm-2.4.dev1/adv_optm/util/centered_decay.py +112 -0
  13. adv_optm-2.4.dev1/adv_optm/util/param_update.py +286 -0
  14. adv_optm-2.4.dev1/adv_optm/util/scaled_optm.py +137 -0
  15. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/update_util.py +3 -1
  16. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
  17. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/SOURCES.txt +2 -0
  18. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/setup.py +1 -1
  19. adv_optm-2.3.dev2/adv_optm/util/param_update.py +0 -177
  20. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/LICENSE +0 -0
  21. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/README.md +0 -0
  22. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  23. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/__init__.py +0 -0
  24. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/Kourkoutas.py +0 -0
  25. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/Muon_util.py +0 -0
  26. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/OrthoGrad.py +0 -0
  27. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/__init__.py +0 -0
  28. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/factorization_util.py +0 -0
  29. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/lion_k.py +0 -0
  30. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
  31. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/requires.txt +0 -0
  32. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/top_level.txt +0 -0
  33. {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.3.dev2
3
+ Version: 2.4.dev1
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -10,11 +10,6 @@ from .optim import (
10
10
  SignSGD_adv,
11
11
  )
12
12
 
13
- from .stiefel_optm.Stiefel_LoRA import (
14
- Stiefel_LoRA,
15
- )
16
-
17
-
18
13
  __all__ = [
19
14
  "AdamW_adv",
20
15
  "Prodigy_adv",
@@ -25,7 +20,6 @@ __all__ = [
25
20
  "Muon_adv",
26
21
  "AdaMuon_adv",
27
22
  "SignSGD_adv",
28
- "Stiefel_LoRA",
29
23
  ]
30
24
 
31
- __version__ = "2.3.dev2"
25
+ __version__ = "2.4.dev1"
@@ -8,6 +8,7 @@ from ..util.factorization_util import _get_effective_shape, _factorize_state, _r
8
8
  from ..util.OrthoGrad import _orthogonalize_gradient
9
9
  from ..util.Kourkoutas import KourkoutasHelper
10
10
  from ..util import Muon_AuxAdam
11
+ from ..util.centered_decay import _init_anchor
11
12
 
12
13
  A = 4 / math.pi
13
14
 
@@ -87,6 +88,15 @@ class AdaMuon_adv(torch.optim.Optimizer):
87
88
  (default: False)
88
89
  mars_gamma (float): The scaling coefficient for MARS gradient correction.
89
90
  (default: 0.025)
91
+ centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
92
+ toward zero, they are decayed toward their initial values (anchors). This
93
+ can be used together with standard weight decay. (default: 0.0)
94
+ centered_wd_mode (str): The quantization format used to store the anchor
95
+ weights to save VRAM. Options include:
96
+ 'full': Stores anchors in the original parameter's precision.
97
+ 'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
98
+ 'int8': Uses 8-bit block-wise quantization (block size 128).
99
+ 'int4': Uses 4-bit block-wise quantization (block size 32).
90
100
  nnmf_factor (bool): whether to use the factorization or disable it to use
91
101
  the uncompressed optimizer. (default: False)
92
102
  use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
@@ -157,6 +167,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
157
167
  # Spectral Normalization
158
168
  n_layers: int = 1,
159
169
  spectral_normalization: bool = False,
170
+ # Centered WD
171
+ centered_wd: float = 0.0,
172
+ centered_wd_mode: str = 'float8',
160
173
  # torch.compile
161
174
  compiled_optimizer: bool = False,
162
175
  # --- AdamW_adv specific parameters ---
@@ -214,6 +227,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
214
227
  "approx_mars": approx_mars, "mars_gamma": mars_gamma,
215
228
  # Spectral Normalization
216
229
  "n_layers": n_layers, "spectral_normalization": spectral_normalization,
230
+ # Centered WD
231
+ "centered_wd": centered_wd,
232
+ "centered_wd_mode": centered_wd_mode,
217
233
  # AdamW_adv defaults
218
234
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
219
235
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -261,6 +277,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
261
277
  if compiled_optimizer:
262
278
  self.compile(fullgraph=True)
263
279
 
280
+ def load_state_dict(self, state_dict: dict) -> None:
281
+ """
282
+ Overrides default load_state_dict to implement a workaround for PyTorch's
283
+ automatic dtype casting. It ensures factorized states remain float32 for
284
+ stability, preserves integer/float8 quantized anchor states, and forces
285
+ standard states onto the parameter's current dtype/device.
286
+ """
287
+ super().load_state_dict(state_dict)
288
+ param_update.post_process_loaded_state(self)
289
+
264
290
  @property
265
291
  def supports_fused_back_pass(self):
266
292
  return True
@@ -344,6 +370,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
344
370
  # Note: This requires full-rank memory even if factored
345
371
  state['last_grad'] = torch.zeros_like(p, device=device, dtype=p.dtype)
346
372
 
373
+ _init_anchor(p, state, group)
374
+
347
375
  group['adam_kourkoutas_beta'] = False
348
376
  state['is_muon'] = True # Workaround as group was acting weirdly; passing muon params in adam path
349
377
 
@@ -9,6 +9,8 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
9
9
  from ..util.update_util import _grams_update, _cautious_update
10
10
  from ..util.OrthoGrad import _orthogonalize_gradient
11
11
  from ..util.Kourkoutas import KourkoutasHelper
12
+ from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
13
+ from ..util.centered_decay import _init_anchor
12
14
 
13
15
  A = 4 / math.pi
14
16
 
@@ -78,8 +80,19 @@ class AdamW_adv(torch.optim.Optimizer):
78
80
  and returns a unique, hashable key representing its "layer" or "bucket".
79
81
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
80
82
  (default: None)
83
+ centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
84
+ toward zero, they are decayed toward their initial values (anchors). This
85
+ can be used together with standard weight decay. (default: 0.0)
86
+ centered_wd_mode (str): The quantization format used to store the anchor
87
+ weights to save VRAM. Options include:
88
+ 'full': Stores anchors in the original parameter's precision.
89
+ 'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
90
+ 'int8': Uses 8-bit block-wise quantization (block size 128).
91
+ 'int4': Uses 4-bit block-wise quantization (block size 32).
81
92
  nnmf_factor (bool): whether to use the factorization or disable it to use
82
93
  the uncompressed optimizer. (default: False)
94
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
95
+ while only factorizing the second moment. (default: False)
83
96
  """
84
97
 
85
98
  def __init__(
@@ -114,9 +127,15 @@ class AdamW_adv(torch.optim.Optimizer):
114
127
  k_warmup_steps: int = 0,
115
128
  k_logging: int = 0,
116
129
  layer_key_fn: Optional[Callable] = None,
130
+ # Scaled Optimizer
131
+ scaled_optm: bool = False,
132
+ # Centered WD
133
+ centered_wd: float = 0.0,
134
+ centered_wd_mode: str = 'float8',
117
135
  # SMMF factorization
118
136
  nnmf_factor: bool = False,
119
137
  vector_reshape: bool = False,
138
+ factored_2nd: bool = False,
120
139
  # torch.compile
121
140
  compiled_optimizer: bool = False,
122
141
  ):
@@ -137,12 +156,14 @@ class AdamW_adv(torch.optim.Optimizer):
137
156
 
138
157
  defaults = {
139
158
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
140
- "vector_reshape": vector_reshape, "use_atan2": use_atan2,
159
+ "use_atan2": use_atan2,
141
160
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
142
161
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
143
162
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
144
163
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
145
- "nnmf_factor": nnmf_factor
164
+ "scaled_optm": scaled_optm,
165
+ "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
166
+ "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
146
167
  }
147
168
  self.stochastic_rounding = stochastic_rounding
148
169
  self.cautious_mask = cautious_mask
@@ -150,6 +171,7 @@ class AdamW_adv(torch.optim.Optimizer):
150
171
  self.use_AdEMAMix = use_AdEMAMix
151
172
  self.kourkoutas_beta = kourkoutas_beta
152
173
  self.layer_key_fn = layer_key_fn
174
+ self._init_lr = lr
153
175
  super().__init__(params, defaults)
154
176
 
155
177
  if self.kourkoutas_beta:
@@ -167,6 +189,16 @@ class AdamW_adv(torch.optim.Optimizer):
167
189
  if compiled_optimizer:
168
190
  self.compile(fullgraph=True)
169
191
 
192
+ def load_state_dict(self, state_dict: dict) -> None:
193
+ """
194
+ Overrides default load_state_dict to implement a workaround for PyTorch's
195
+ automatic dtype casting. It ensures factorized states remain float32 for
196
+ stability, preserves integer/float8 quantized anchor states, and forces
197
+ standard states onto the parameter's current dtype/device.
198
+ """
199
+ super().load_state_dict(state_dict)
200
+ param_update.post_process_loaded_state(self)
201
+
170
202
  @property
171
203
  def supports_fused_back_pass(self):
172
204
  return True
@@ -194,6 +226,7 @@ class AdamW_adv(torch.optim.Optimizer):
194
226
  state['factored'] = (
195
227
  group['nnmf_factor'] and
196
228
  not (len(p.shape) == 1 and not group['vector_reshape'])
229
+ or group["factored_2nd"]
197
230
  )
198
231
 
199
232
  dtype = torch.float32 if state['factored'] else p.dtype
@@ -203,18 +236,25 @@ class AdamW_adv(torch.optim.Optimizer):
203
236
  state['effective_shape'] = _get_effective_shape(p.numel())
204
237
  d1, d2 = state['effective_shape']
205
238
 
206
- # First moment (m)
207
- if group['betas'][0] > 0:
208
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
209
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
210
- packed_d2 = (d2 + 7) // 8
211
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
212
- # AdEMAMix slow moment (m_slow)
213
- if self.use_AdEMAMix:
214
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
215
- state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
216
- packed_d2 = (d2 + 7) // 8
217
- state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
239
+ if not group.get('factored_2nd', False):
240
+ # First moment (m)
241
+ if group['betas'][0] > 0:
242
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
243
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
244
+ packed_d2 = (d2 + 7) // 8
245
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
246
+ # AdEMAMix slow moment (m_slow)
247
+ if self.use_AdEMAMix:
248
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
249
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
250
+ packed_d2 = (d2 + 7) // 8
251
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
252
+ else:
253
+ if group['betas'][0] > 0:
254
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
255
+ if self.use_AdEMAMix:
256
+ state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
257
+
218
258
  # Second moment (v)
219
259
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
220
260
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
@@ -228,6 +268,11 @@ class AdamW_adv(torch.optim.Optimizer):
228
268
  # Second moment (v)
229
269
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
230
270
 
271
+ if group.get('scaled_optm', False) and is_spectral(p):
272
+ init_spectral_norm(group, state, p)
273
+
274
+ _init_anchor(p, state, group)
275
+
231
276
  beta1, beta2 = group['betas']
232
277
 
233
278
  current_step = state['step']
@@ -275,32 +320,42 @@ class AdamW_adv(torch.optim.Optimizer):
275
320
  # Accumulate current grad's norm for the *next* step
276
321
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
277
322
 
323
+ # Determine if we are using dense first moments alongside a factored second moment
324
+ factored_2nd = group.get('factored_2nd', False)
325
+
278
326
  if state['factored']:
279
327
  d1, d2 = state['effective_shape']
280
328
  grad_reshaped = grad.view(d1, d2)
281
329
 
282
330
  # Reconstruct momentum from previous step's factors
283
331
  if beta1 > 0:
284
- mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
332
+ if factored_2nd:
333
+ mt = state['exp_avg'].view(d1, d2)
334
+ else:
335
+ mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
285
336
 
286
337
  # Update momentum in full-size
287
338
  mt.lerp_(grad_reshaped, 1.0 - beta1)
288
339
 
289
- # Factorize
290
- state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
340
+ if not factored_2nd:
341
+ # Factorize
342
+ state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
291
343
 
292
344
  if self.grams_moment:
293
- update_mt = _grams_update(mt, grad_reshaped, inplace=True)
345
+ update_mt = _grams_update(mt, grad_reshaped, inplace=not factored_2nd)
294
346
  elif self.cautious_mask:
295
- update_mt = _cautious_update(mt, grad_reshaped, inplace=True)
347
+ update_mt = _cautious_update(mt, grad_reshaped, inplace=not factored_2nd)
296
348
  else:
297
- update_mt = mt
349
+ update_mt = mt if not factored_2nd else mt.clone()
298
350
 
299
351
  vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
300
352
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
301
353
 
302
354
  if self.use_AdEMAMix:
303
- mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
355
+ if factored_2nd:
356
+ mt_slow = state['exp_avg_slow'].view(d1, d2)
357
+ else:
358
+ mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
304
359
 
305
360
  mt_slow.lerp_(grad_reshaped, 1.0 - beta3_ema)
306
361
 
@@ -308,9 +363,11 @@ class AdamW_adv(torch.optim.Optimizer):
308
363
  update = update_mt.add_(mt_slow, alpha=alpha)
309
364
  else:
310
365
  update = grad_reshaped.add(mt_slow, alpha=alpha)
311
- # Factorize
312
- state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
313
- del mt_slow
366
+
367
+ if not factored_2nd:
368
+ # Factorize
369
+ state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
370
+ del mt_slow
314
371
  else:
315
372
  if beta1 > 0:
316
373
  update = update_mt
@@ -330,8 +387,7 @@ class AdamW_adv(torch.optim.Optimizer):
330
387
  update.div_(denom)
331
388
  del vt
332
389
 
333
- update_scaling = step_size * A if group['use_atan2'] else step_size
334
- update = update.view(p.shape).mul_(update_scaling)
390
+ update = update.view(p.shape)
335
391
 
336
392
  else: # Standard AdamW logic for non-factored tensors
337
393
  if beta1 > 0:
@@ -369,7 +425,10 @@ class AdamW_adv(torch.optim.Optimizer):
369
425
  update.div_(denom)
370
426
  del denom
371
427
 
372
- update_scaling = step_size * A if group['use_atan2'] else step_size
428
+ update_scaling = step_size * A if group['use_atan2'] else step_size
429
+ if group.get('scaled_optm', False):
430
+ update = scale_update(p, update, update_scaling, vector_state=state.get('spectral_v'))
431
+ else:
373
432
  update.mul_(update_scaling)
374
433
 
375
434
  param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
@@ -8,6 +8,8 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
8
8
  from ..util.OrthoGrad import _orthogonalize_gradient
9
9
  from ..util.Kourkoutas import KourkoutasHelper
10
10
  from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
11
+ from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
12
+ from ..util.centered_decay import _init_anchor
11
13
 
12
14
  A = 4 / math.pi
13
15
 
@@ -94,8 +96,19 @@ class Adopt_adv(torch.optim.Optimizer):
94
96
  and returns a unique, hashable key representing its "layer" or "bucket".
95
97
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
96
98
  (default: None)
99
+ centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
100
+ toward zero, they are decayed toward their initial values (anchors). This
101
+ can be used together with standard weight decay. (default: 0.0)
102
+ centered_wd_mode (str): The quantization format used to store the anchor
103
+ weights to save VRAM. Options include:
104
+ 'full': Stores anchors in the original parameter's precision.
105
+ 'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
106
+ 'int8': Uses 8-bit block-wise quantization (block size 128).
107
+ 'int4': Uses 4-bit block-wise quantization (block size 32).
97
108
  nnmf_factor (bool): whether to use the factorization or disable it to use
98
109
  the uncompressed optimizer. (default: False)
110
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
111
+ while only factorizing the second moment. (default: False)
99
112
  """
100
113
 
101
114
  def __init__(
@@ -133,9 +146,15 @@ class Adopt_adv(torch.optim.Optimizer):
133
146
  k_warmup_steps: int = 0,
134
147
  k_logging: int = 0,
135
148
  layer_key_fn: Optional[Callable] = None,
149
+ # Scaled Optimizer
150
+ scaled_optm: bool = False,
151
+ # Centered WD
152
+ centered_wd: float = 0.0,
153
+ centered_wd_mode: str = 'float8',
136
154
  # SMMF factorization
137
155
  nnmf_factor: bool = False,
138
- vector_reshape: bool = False,
156
+ vector_reshape: bool = True,
157
+ factored_2nd: bool = False,
139
158
  # torch.compile
140
159
  compiled_optimizer: bool = False,
141
160
  ):
@@ -163,11 +182,14 @@ class Adopt_adv(torch.optim.Optimizer):
163
182
 
164
183
  defaults = {
165
184
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
166
- "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
185
+ "beta3_ema": beta3_ema, "alpha": alpha,
167
186
  "alpha_grad": alpha_grad,
168
187
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
169
188
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
170
- "nnmf_factor": nnmf_factor,
189
+ "scaled_optm": scaled_optm,
190
+ "centered_wd": centered_wd,
191
+ "centered_wd_mode": centered_wd_mode,
192
+ "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
171
193
  "compiled_optimizer": compiled_optimizer,
172
194
  }
173
195
  self.clip_lambda = clip_lambda
@@ -180,6 +202,7 @@ class Adopt_adv(torch.optim.Optimizer):
180
202
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
181
203
  self.kourkoutas_beta = kourkoutas_beta
182
204
  self.layer_key_fn = layer_key_fn
205
+ self._init_lr = lr
183
206
  super().__init__(params, defaults)
184
207
 
185
208
  if self.kourkoutas_beta:
@@ -196,6 +219,16 @@ class Adopt_adv(torch.optim.Optimizer):
196
219
  if compiled_optimizer:
197
220
  self.compile(fullgraph=True)
198
221
 
222
+ def load_state_dict(self, state_dict: dict) -> None:
223
+ """
224
+ Overrides default load_state_dict to implement a workaround for PyTorch's
225
+ automatic dtype casting. It ensures factorized states remain float32 for
226
+ stability, preserves integer/float8 quantized anchor states, and forces
227
+ standard states onto the parameter's current dtype/device.
228
+ """
229
+ super().load_state_dict(state_dict)
230
+ param_update.post_process_loaded_state(self)
231
+
199
232
  @property
200
233
  def supports_fused_back_pass(self): return True
201
234
  @property
@@ -218,6 +251,7 @@ class Adopt_adv(torch.optim.Optimizer):
218
251
  state['factored'] = (
219
252
  group['nnmf_factor'] and
220
253
  not (len(p.shape) == 1 and not group['vector_reshape'])
254
+ or group["factored_2nd"]
221
255
  )
222
256
 
223
257
  dtype = torch.float32 if state['factored'] else p.dtype
@@ -226,18 +260,24 @@ class Adopt_adv(torch.optim.Optimizer):
226
260
  state['effective_shape'] = _get_effective_shape(p.numel())
227
261
  d1, d2 = state['effective_shape']
228
262
 
229
- # First moment (m)
230
- if group['betas'][0] > 0:
231
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
232
- state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
233
- packed_d2 = (d2 + 7) // 8
234
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
235
- # AdEMAMix slow moment (m_slow)
236
- if self.use_AdEMAMix:
237
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
238
- state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
239
- packed_d2 = (d2 + 7) // 8
240
- state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
263
+ if not group.get('factored_2nd', False):
264
+ # First moment (m)
265
+ if group['betas'][0] > 0:
266
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
267
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
268
+ packed_d2 = (d2 + 7) // 8
269
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
270
+ # AdEMAMix slow moment (m_slow)
271
+ if self.use_AdEMAMix:
272
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
273
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
274
+ packed_d2 = (d2 + 7) // 8
275
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
276
+ else:
277
+ if group['betas'][0] > 0:
278
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
279
+ if self.use_AdEMAMix:
280
+ state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
241
281
  # Second moment (v)
242
282
  vt_init = grad.to(dtype).view(d1, d2).square()
243
283
  # Allocate NMF factors for vt
@@ -253,6 +293,11 @@ class Adopt_adv(torch.optim.Optimizer):
253
293
  state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
254
294
  state['exp_avg_sq'] = grad.to(dtype).square()
255
295
 
296
+ if group.get('scaled_optm', False) and is_spectral(p):
297
+ init_spectral_norm(group, state, p)
298
+
299
+ _init_anchor(p, state, group)
300
+
256
301
  beta1, beta2 = group['betas']
257
302
 
258
303
  current_step = state['step']
@@ -280,7 +325,7 @@ class Adopt_adv(torch.optim.Optimizer):
280
325
  step_param_fn = self._step_parameter
281
326
 
282
327
  if self.Simplified_AdEMAMix:
283
- lr = _scale_sim_AdEMAMix_update(beta1, state['step'] + 1, group["alpha_grad"], lr)
328
+ lr = _scale_sim_AdEMAMix_update(beta1, state['step'] + 1, group["alpha_grad"], lr, group.get('scaled_optm', False))
284
329
 
285
330
  step_param_fn(p, grad, state, group, lr, beta1, beta2, random_int_tensor)
286
331
 
@@ -302,6 +347,9 @@ class Adopt_adv(torch.optim.Optimizer):
302
347
  # Accumulate current grad's norm for the *next* step
303
348
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
304
349
 
350
+ # Determine if we are using dense first-moments alongside a factored second-order second-moment
351
+ factored_2nd = group.get('factored_2nd', False)
352
+
305
353
  if state['factored']:
306
354
  d1, d2 = state['effective_shape']
307
355
  grad_reshaped = grad.view(d1, d2)
@@ -328,35 +376,47 @@ class Adopt_adv(torch.optim.Optimizer):
328
376
 
329
377
  # ADOPT Step B: Update momentum m_t using normalized gradient
330
378
  if beta1 > 0:
331
- # Reconstruct m_{t-1}
332
- mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
379
+ if factored_2nd:
380
+ mt = state['exp_avg'].view(d1, d2)
381
+ else:
382
+ # Reconstruct m_{t-1}
383
+ mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
384
+
333
385
  if self.Simplified_AdEMAMix:
334
386
  mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
335
387
  else:
336
388
  mt.lerp_(normalized_grad, 1.0 - beta1)
337
389
 
338
- # Factorize
339
- state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
390
+ if not factored_2nd:
391
+ # Factorize
392
+ state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
340
393
 
341
394
  if self.grams_moment:
342
- update_mt = _grams_update(mt, grad_reshaped, inplace=True)
395
+ update_mt = _grams_update(mt, grad_reshaped, inplace=not factored_2nd)
343
396
  elif self.cautious_mask:
344
- update_mt = _cautious_update(mt, grad_reshaped, inplace=True)
397
+ update_mt = _cautious_update(mt, grad_reshaped, inplace=not factored_2nd)
345
398
  else:
346
- update_mt = mt
399
+ update_mt = mt if not factored_2nd else mt.clone()
347
400
 
348
401
  if self.use_AdEMAMix:
349
- # Reconstruct AdEMAMix EMA
350
- mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
402
+ if factored_2nd:
403
+ mt_slow = state['exp_avg_slow'].view(d1, d2)
404
+ else:
405
+ # Reconstruct AdEMAMix EMA
406
+ mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
407
+
351
408
  mt_slow.lerp_(normalized_grad, 1.0 - beta3_ema)
409
+
352
410
  if beta1 > 0:
353
411
  update = update_mt.add_(mt_slow, alpha=alpha)
354
412
  del normalized_grad
355
413
  else:
356
414
  update = normalized_grad.add_(mt_slow, alpha=alpha)
357
- # Factorize
358
- state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
359
- del mt_slow
415
+ if not factored_2nd:
416
+ # Factorize
417
+ state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
418
+ del mt_slow
419
+
360
420
  elif self.Simplified_AdEMAMix:
361
421
  update = update_mt.add_(normalized_grad, alpha=alpha_grad)
362
422
  del normalized_grad
@@ -369,9 +429,6 @@ class Adopt_adv(torch.optim.Optimizer):
369
429
 
370
430
  update = update.view(p.shape)
371
431
 
372
- update_scaling = lr * A if self.use_atan2 else lr
373
- update.mul_(update_scaling)
374
-
375
432
  else: # Standard ADOPT logic for non-factored tensors
376
433
  vt = state['exp_avg_sq'] # v_{t-1}
377
434
 
@@ -418,12 +475,17 @@ class Adopt_adv(torch.optim.Optimizer):
418
475
  else:
419
476
  update = normalized_grad
420
477
 
421
- update_scaling = lr * A if self.use_atan2 else lr
422
- update.mul_(update_scaling)
423
478
 
424
479
  # Update second moment v_t for the next step using raw g_t
425
480
  vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
426
481
 
482
+ update_scaling = lr * A if self.use_atan2 else lr
483
+
484
+ if group.get('scaled_optm', False):
485
+ update = scale_update(p, update, update_scaling, vector_state=state.get('spectral_v'))
486
+ else:
487
+ update.mul_(update_scaling)
488
+
427
489
  # Parameter Update
428
490
  param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
429
491