adv-optm 2.4.dev22__py3-none-any.whl → 2.4.dev24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.4.dev22"
23
+ __version__ = "2.4.dev24"
@@ -99,7 +99,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
99
99
  use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
100
100
  either here or via `optim_type` in parameter groups. (default: None)
101
101
  state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype), 'fp32',
102
- 'bf16_sr' (BF16 with stochastic rounding), 'fp8_sr', 'int8_sr'.
102
+ 'bf16_sr' (BF16 with stochastic rounding), 'int8_sr'.
103
103
  (default: 'auto')
104
104
  factored_2nd (bool): Factorize only the second moment (v_t) using SMMF
105
105
  low-rank compression while keeping the first moment (momentum_buffer)
@@ -123,7 +123,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
123
123
  adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
124
124
  adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
125
125
  adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
126
- adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
126
+ adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'int8_sr', 'factored'. (default: 'auto')
127
127
  adam_nnmf_factor (bool): 1-bit factored for AdamW.
128
128
  adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
129
129
  """
@@ -157,7 +157,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
157
157
  # Boolean to split param
158
158
  use_muon: bool | None = None,
159
159
  # States precision (Muon path)
160
- state_precision: str = "auto", # 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr'
160
+ state_precision: str = "auto", # 'fp32', 'bf16_sr', 'int8_sr'
161
161
  # Factorized second moment only
162
162
  factored_2nd: bool = False,
163
163
  # Update geometry parameters
@@ -220,7 +220,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
220
220
  state_precision = "factored"
221
221
 
222
222
  state_precision = state_precision.lower()
223
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
223
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
224
224
  if state_precision not in valid_precisions:
225
225
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
226
226
 
@@ -374,6 +374,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
374
374
  d1, d2 = state['effective_shape']
375
375
  state['mu_vbuf_nmf'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
376
376
  state['mv_vbuf_nmf'] = torch.zeros(d2, device=p.device, dtype=torch.float32)
377
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
377
378
  elif not group['normuon_variant']:
378
379
  init_state_tensor(state, 'second_momentum_buffer', p.shape, actual_precision, p.device, default_dtype, non_neg=True)
379
380
 
@@ -454,8 +455,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
454
455
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
455
456
  elif actual_precision == 'int8_sr':
456
457
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
457
- elif actual_precision == 'fp8_sr':
458
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
459
458
  else:
460
459
  adam_step_param = Muon_AuxAdam._adam_step_parameter
461
460
 
@@ -475,8 +474,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
475
474
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
476
475
  elif actual_precision == 'int8_sr':
477
476
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
478
- elif actual_precision == 'fp8_sr':
479
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
480
477
  if group['low_rank_ortho']:
481
478
  random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
482
479
  else:
@@ -84,7 +84,7 @@ class AdamW_adv(torch.optim.Optimizer):
84
84
  while only factorizing the second moment. (default: False)
85
85
  state_precision (str): Precision method for optimizer states. Options: 'auto'
86
86
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
87
- stochastic rounding), 'fp16' , 'fp8_sr', 'int8_sr'. (default: 'auto')
87
+ stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
88
88
  """
89
89
 
90
90
  def __init__(
@@ -124,7 +124,7 @@ class AdamW_adv(torch.optim.Optimizer):
124
124
  centered_wd: float = 0.0,
125
125
  centered_wd_mode: str = 'float8',
126
126
  # States precision
127
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
127
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
128
128
  # Factorized second moment only
129
129
  factored_2nd: bool = False,
130
130
  # SMMF factorization (legacy)
@@ -145,7 +145,7 @@ class AdamW_adv(torch.optim.Optimizer):
145
145
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
146
146
 
147
147
  state_precision = state_precision.lower()
148
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
148
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
149
149
  if state_precision not in valid_precisions:
150
150
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
151
151
 
@@ -264,6 +264,7 @@ class AdamW_adv(torch.optim.Optimizer):
264
264
  d1, d2 = state['effective_shape']
265
265
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
266
266
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
267
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8)
267
268
  else:
268
269
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
269
270
 
@@ -314,8 +315,6 @@ class AdamW_adv(torch.optim.Optimizer):
314
315
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
315
316
  elif group['actual_state_precision'] == 'int8_sr':
316
317
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
317
- elif group['actual_state_precision'] == 'fp8_sr':
318
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
319
318
  step_param_fn = self._compiled_step_parameter
320
319
  else:
321
320
  step_param_fn = self._step_parameter
@@ -88,7 +88,7 @@ class Adopt_adv(torch.optim.Optimizer):
88
88
  while only factorizing the second moment. (default: False)
89
89
  state_precision (str): Precision method for Adopt states. Options: 'auto'
90
90
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
91
- stochastic rounding), 'fp16' , 'fp8_sr', 'int8_sr'. (default: 'auto')
91
+ stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
92
92
  """
93
93
 
94
94
  def __init__(
@@ -126,7 +126,7 @@ class Adopt_adv(torch.optim.Optimizer):
126
126
  centered_wd: float = 0.0,
127
127
  centered_wd_mode: str = 'float8',
128
128
  # States precision
129
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
129
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
130
130
  # Factorized second moment only
131
131
  factored_2nd: bool = False,
132
132
  # SMMF factorization (legacy)
@@ -148,7 +148,7 @@ class Adopt_adv(torch.optim.Optimizer):
148
148
 
149
149
 
150
150
  state_precision = state_precision.lower()
151
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
151
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
152
152
  if state_precision not in valid_precisions:
153
153
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
154
154
 
@@ -236,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
236
236
 
237
237
  dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
238
238
 
239
- vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'fp8_sr', 'int8_sr']) else dtype
239
+ vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
240
240
  vt_init = grad.pow(2).to(vt_dtype) * (1 - group['betas'][1])
241
241
 
242
242
  if state['factored']:
@@ -262,6 +262,7 @@ class Adopt_adv(torch.optim.Optimizer):
262
262
  state['effective_shape'] = _get_effective_shape(p.numel())
263
263
  d1, d2 = state['effective_shape']
264
264
  state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init.view(d1, d2))
265
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
265
266
  else:
266
267
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype)
267
268
  set_state(state, 'exp_avg_sq', vt_init, actual_precision, None, non_neg=True)
@@ -316,8 +317,6 @@ class Adopt_adv(torch.optim.Optimizer):
316
317
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
317
318
  elif group['actual_state_precision'] == 'int8_sr':
318
319
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
319
- elif group['actual_state_precision'] == 'fp8_sr':
320
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
321
320
  step_param_fn = self._compiled_step_parameter
322
321
  else:
323
322
  lr = group['lr']
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
33
33
  stochastic_rounding (bool, optional): whether to use stochastic
34
34
  rounding for BF16 parameter updates (default: True).
35
35
  orthogonal_gradient (bool): whether to orthogonalize the gradient (default: False).
36
- clip_threshold (float, optional): whether to clip the gradients norm
37
- per-parameter (default: 0.0).
38
36
  kappa_p (float, optional): The p-value for the Lp-norm in Lion-K (domain [1.0, 2.0]).
39
37
  - 1.0: Standard Lion (sign update).
40
38
  - 2.0: Spherical Lion (normalized L2 update).
@@ -47,7 +47,7 @@ class Muon_adv(torch.optim.Optimizer):
47
47
  use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
48
48
  either here or via `optim_type` in parameter groups. (default: None)
49
49
  state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype), 'fp32',
50
- 'bf16_sr' (BF16 with stochastic rounding), 'fp8_sr', 'int8_sr'.
50
+ 'bf16_sr' (BF16 with stochastic rounding), 'int8_sr'.
51
51
  (default: 'auto')
52
52
  low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
53
53
  projects the update to a lower rank before orthogonalization.
@@ -98,7 +98,7 @@ class Muon_adv(torch.optim.Optimizer):
98
98
  adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
99
99
  adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
100
100
  adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
101
- adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
101
+ adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'int8_sr', 'factored'. (default: 'auto')
102
102
  adam_nnmf_factor (bool): 1-bit factored for AdamW.
103
103
  adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
104
104
  """
@@ -130,7 +130,7 @@ class Muon_adv(torch.optim.Optimizer):
130
130
  # Boolean to split param
131
131
  use_muon: bool | None = None,
132
132
  # States precision (Muon path)
133
- state_precision: str = "auto", # 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr'
133
+ state_precision: str = "auto", # 'fp32', 'bf16_sr', 'int8_sr'
134
134
  # Low-rank Muon
135
135
  low_rank_ortho: bool = False,
136
136
  ortho_rank: int = 128,
@@ -193,7 +193,7 @@ class Muon_adv(torch.optim.Optimizer):
193
193
  state_precision = "factored"
194
194
 
195
195
  state_precision = state_precision.lower()
196
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
196
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
197
197
  if state_precision not in valid_precisions:
198
198
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
199
199
 
@@ -406,8 +406,6 @@ class Muon_adv(torch.optim.Optimizer):
406
406
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
407
407
  elif actual_precision == 'int8_sr':
408
408
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
409
- elif actual_precision == 'fp8_sr':
410
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
411
409
  else:
412
410
  adam_step_param = Muon_AuxAdam._adam_step_parameter
413
411
 
@@ -427,8 +425,6 @@ class Muon_adv(torch.optim.Optimizer):
427
425
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
428
426
  elif actual_precision == 'int8_sr':
429
427
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
430
- elif actual_precision == 'fp8_sr':
431
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
432
428
  if group['low_rank_ortho']:
433
429
  random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
434
430
  else:
@@ -124,7 +124,7 @@ class Prodigy_adv(torch.optim.Optimizer):
124
124
  nesterov: bool = False,
125
125
  nesterov_coef: float | None = None,
126
126
  # States precision
127
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
127
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
128
128
  # Factorized second moment only
129
129
  factored_2nd: bool = False,
130
130
  # SMMF factorization (legacy)
@@ -168,7 +168,7 @@ class Prodigy_adv(torch.optim.Optimizer):
168
168
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
169
169
 
170
170
  state_precision = state_precision.lower()
171
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
171
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
172
172
  if state_precision not in valid_precisions:
173
173
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
174
174
 
@@ -311,6 +311,7 @@ class Prodigy_adv(torch.optim.Optimizer):
311
311
  d1, d2 = state['effective_shape']
312
312
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
313
313
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
314
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
314
315
  else:
315
316
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
316
317
 
@@ -358,8 +359,6 @@ class Prodigy_adv(torch.optim.Optimizer):
358
359
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
359
360
  elif group['actual_state_precision'] == 'int8_sr':
360
361
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
361
- elif group['actual_state_precision'] == 'fp8_sr':
362
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
363
362
  step_param_fn = self._compiled_step_parameter
364
363
  else:
365
364
  d = group['d']
@@ -44,7 +44,7 @@ class SignSGD_adv(torch.optim.Optimizer):
44
44
  'int4': Uses 4-bit block-wise quantization (block size 32).
45
45
  state_precision (str): Precision method for optimizer states. Options: 'auto'
46
46
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
47
- stochastic rounding), 'fp16' , 'fp8_sr', 'int8_sr'. (default: 'auto')
47
+ stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
48
48
  nnmf_factor (bool): whether to use the factorization or use the
49
49
  uncompressed optimizer. (default: True)
50
50
  """
@@ -70,13 +70,13 @@ class SignSGD_adv(torch.optim.Optimizer):
70
70
  nesterov_coef: float | None = None,
71
71
  # Normalization then Momentum
72
72
  normed_momentum: bool = False,
73
- # SNR Precondition
73
+ # SNR Precondition (requires normed_momentum)
74
74
  snr_cond: bool = False,
75
75
  # Centered WD
76
76
  centered_wd: float = 0.0,
77
77
  centered_wd_mode: str = 'float8',
78
78
  # States precision
79
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
79
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
80
80
  # Spectral Normed Optimizer
81
81
  spectral_normalization: bool = False,
82
82
  # SMMF factorization
@@ -95,7 +95,7 @@ class SignSGD_adv(torch.optim.Optimizer):
95
95
  raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
96
96
 
97
97
  state_precision = state_precision.lower()
98
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
98
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
99
99
  if state_precision not in valid_precisions:
100
100
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
101
101
 
@@ -230,8 +230,6 @@ class SignSGD_adv(torch.optim.Optimizer):
230
230
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
231
231
  elif group['actual_state_precision'] == 'int8_sr':
232
232
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
233
- elif group['actual_state_precision'] == 'fp8_sr':
234
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
235
233
 
236
234
  if group.get('stochastic_sign', False) and not is_vector:
237
235
  random_noise_tensor = param_update._get_random_noise_for_sso(p)
@@ -254,7 +252,8 @@ class SignSGD_adv(torch.optim.Optimizer):
254
252
  nesterov = group.get('nesterov', False)
255
253
  nesterov_coef = group.get('nesterov_coef', None)
256
254
  sso = group.get('stochastic_sign', False)
257
- snr_cond = group.get('snr_cond', False) and group.get('normed_momentum', False) and momentum > 0
255
+ normed_mt = group.get('normed_momentum', False)
256
+ snr_cond = group.get('snr_cond', False) and normed_mt and momentum > 0
258
257
 
259
258
  denom = None
260
259
  wd_target = None
@@ -263,7 +262,7 @@ class SignSGD_adv(torch.optim.Optimizer):
263
262
  if group["orthogonal_gradient"]:
264
263
  grad = _orthogonalize_gradient(p, grad)
265
264
 
266
- if group.get('normed_momentum', False):
265
+ if normed_mt:
267
266
  if sso:
268
267
  grad = apply_stochastic_sign_(grad, noise=random_noise_tensor, is_vector=is_vector)
269
268
  else:
@@ -285,7 +284,12 @@ class SignSGD_adv(torch.optim.Optimizer):
285
284
 
286
285
  if nesterov:
287
286
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
288
- raw_update = grad_reshaped.lerp(exp_avg, nv_coef)
287
+ if normed_mt:
288
+ # Scale the normalized gradient down to match the buffer's variance
289
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
290
+ raw_update = (grad_reshaped * ema_std).lerp_(exp_avg, nv_coef)
291
+ else:
292
+ raw_update = grad.lerp(exp_avg, nv_coef)
289
293
  else:
290
294
  raw_update = exp_avg.clone()
291
295
 
@@ -309,7 +313,12 @@ class SignSGD_adv(torch.optim.Optimizer):
309
313
 
310
314
  if nesterov:
311
315
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
312
- raw_update = grad.lerp(exp_avg, nv_coef)
316
+ if normed_mt:
317
+ # Scale the normalized gradient down to match the buffer's variance
318
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
319
+ raw_update = (grad * ema_std).lerp_(exp_avg, nv_coef)
320
+ else:
321
+ raw_update = grad.lerp(exp_avg, nv_coef)
313
322
  else:
314
323
  raw_update = exp_avg.clone()
315
324
 
@@ -42,7 +42,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
42
42
  nnmf_factor (bool): whether to use factorization or disable it. (default: False)
43
43
  state_precision (str): Precision method for states. Options: 'auto'
44
44
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr',
45
- 'fp8_sr', 'int8_sr'. (default: 'auto')
45
+ 'int8_sr'. (default: 'auto')
46
46
  compiled_optimizer (bool): Compiles the core step function using torch.compile
47
47
  for faster execution. (default: False)
48
48
  """
@@ -58,7 +58,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
58
58
  orthogonal_sinkhorn: bool = False,
59
59
  # Normalization then Momentum
60
60
  normed_momentum: bool = False,
61
- # SNR Precondition
61
+ # SNR Precondition (requires normed_momentum)
62
62
  snr_cond: bool = False,
63
63
  # Nesterov Momentum
64
64
  nesterov: bool = False,
@@ -93,7 +93,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
93
93
  raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
94
94
 
95
95
  state_precision = state_precision.lower()
96
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
96
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
97
97
  if state_precision not in valid_precisions:
98
98
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
99
99
 
@@ -209,8 +209,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
209
209
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
210
210
  elif group['actual_state_precision'] == 'int8_sr':
211
211
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
212
- elif group['actual_state_precision'] == 'fp8_sr':
213
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
214
212
  step_param_fn = self._compiled_step_parameter
215
213
  else:
216
214
  step_param_fn = self._step_parameter
@@ -226,6 +224,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
226
224
  orthogonal_sinkhorn = group['orthogonal_sinkhorn']
227
225
 
228
226
  momentum = group['momentum']
227
+ normed_mt = group.get('normed_momentum', False)
229
228
  nesterov = group['nesterov']
230
229
  nesterov_coef = group.get('nesterov_coef', None)
231
230
  snr_cond = group.get('snr_cond', False)
@@ -238,7 +237,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
238
237
  wd_target = None
239
238
  cwd_target = None
240
239
 
241
- if group.get('normed_momentum', False):
240
+ if group["orthogonal_gradient"]:
241
+ grad = _orthogonalize_gradient(p, grad)
242
+
243
+ if normed_mt:
242
244
  if not is_vector:
243
245
  # Sinkhorn iterative normalization
244
246
  grad = apply_sr_sinkhorn(grad, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
@@ -246,9 +248,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
246
248
  # For vectors, apply sign operation
247
249
  grad = grad.sign_()
248
250
 
249
- if group["orthogonal_gradient"]:
250
- grad = _orthogonalize_gradient(p, grad)
251
-
252
251
  if state['factored']:
253
252
  d1, d2 = state['effective_shape']
254
253
  grad_reshaped = grad.view(d1, d2)
@@ -272,7 +271,12 @@ class SinkSGD_adv(torch.optim.Optimizer):
272
271
 
273
272
  if nesterov:
274
273
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
275
- update = grad_reshaped.lerp(buf, nv_coef)
274
+ if normed_mt:
275
+ # Scale the normalized gradient down to match the buffer's variance
276
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
277
+ update = (grad_reshaped * ema_std).lerp_(buf, nv_coef)
278
+ else:
279
+ update = grad_reshaped.lerp(buf, nv_coef)
276
280
  else:
277
281
  update = buf.clone()
278
282
  else:
@@ -301,7 +305,12 @@ class SinkSGD_adv(torch.optim.Optimizer):
301
305
 
302
306
  if nesterov:
303
307
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
304
- update = grad.lerp(buf, nv_coef)
308
+ if normed_mt:
309
+ # Scale the normalized gradient down to match the buffer's variance
310
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
311
+ update = (grad * ema_std).lerp_(buf, nv_coef)
312
+ else:
313
+ update = grad.lerp(buf, nv_coef)
305
314
  else:
306
315
  update = buf.clone()
307
316
  else:
@@ -56,6 +56,7 @@ def _init_auxadam_state(self, p, group):
56
56
  d1, d2 = state['effective_shape']
57
57
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
58
58
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
59
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8)
59
60
  else:
60
61
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
61
62
 
@@ -1,4 +1,5 @@
1
1
  import torch
2
+ import math
2
3
 
3
4
  def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
4
5
  """
@@ -17,3 +18,63 @@ def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor
17
18
  g_orth_norm = g_orth.norm(2).add_(1e-30)
18
19
  g_orth_scaled = g_orth * (g_norm / g_orth_norm)
19
20
  return g_orth_scaled.view(original_shape).to(original_dtype)
21
+
22
+
23
+ def iterative_ortho_project(p: torch.Tensor, update: torch.Tensor, iters: int = 5) -> torch.Tensor:
24
+ """
25
+ Applies iterative alternating orthogonal projection to a 2D matrix.
26
+ Projects the update to be orthogonal to the parameter matrix along
27
+ rows and columns sequentially, alternating dimensions.
28
+ Inspired by the Sinkhorn algorithm; 2 iterations are enough to converge
29
+ to a cosine similarity of -1e-4 to -1e-5 (semi-orthogonal).
30
+ """
31
+ # 1D vector case: fall back to the standard OrthoGrad
32
+ is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
33
+ if is_vector:
34
+ return _orthogonalize_gradient(p, update)
35
+
36
+ original_shape = update.shape
37
+
38
+ # 2D+ Matrix Case
39
+ update_2d = update.view(update.shape[0], -1)
40
+ param_2d = p.view(p.shape[0], -1)
41
+
42
+ m, n = update_2d.shape
43
+
44
+ # Dynamically determine the order based on aspect ratio
45
+ row_first = m > n
46
+ dim = 0 if row_first else 1
47
+
48
+ p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
49
+ p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
50
+
51
+ for _ in range(iters):
52
+ # First dimension
53
+ update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_dim, dim)
54
+ # Second dimension
55
+ update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_adim, 1 - dim)
56
+
57
+ return update_2d.view(original_shape)
58
+
59
+
60
+ def _ortho_normed_dim(p_2d: torch.Tensor, update_2d: torch.Tensor, p_norm_sq: torch.Tensor, dim: int) -> torch.Tensor:
61
+ """
62
+ Projects the update to be orthogonal to p along 'dim' and dynamically restores
63
+ the original magnitude of that dimension pre-projection.
64
+ """
65
+ # Record target magnitude before projection
66
+ norm_lb = 1 / math.sqrt(update_2d.shape[dim])
67
+ target_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
68
+
69
+ # Project: g_orth = g - (p * <p, g> / ||p||^2)
70
+ dot_prod = torch.sum(p_2d * update_2d, dim=dim, keepdim=True)
71
+ proj = dot_prod / p_norm_sq
72
+
73
+ # In-place subtraction: update_2d = update_2d - (proj * p_2d)
74
+ # Standard gamma is -1, but -1.01 proved to converge faster
75
+ update_2d.addcmul_(proj, p_2d, value=-1.01)
76
+
77
+ # Magnitude Preservation
78
+ g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
79
+ scale_factor = target_norm / g_orth_norm
80
+ return update_2d.mul_(scale_factor)
@@ -19,9 +19,9 @@ def quantize_blockwise(p, block_size, bits=8):
19
19
  # Pad to multiple of block_size
20
20
  pad_len = (block_size - (numel % block_size)) % block_size
21
21
  if pad_len > 0:
22
- val_padded = F.pad(val_flat, (0, pad_len), mode='replicate')
22
+ val_padded = F.pad(val_flat.unsqueeze(0), (0, pad_len), mode='replicate').squeeze(0)
23
23
  else:
24
- val_padded = val_flat
24
+ val_padded = val_flat.clone()
25
25
 
26
26
  # Block Reshape
27
27
  val_blocks = val_padded.view(-1, block_size).float()
@@ -29,13 +29,14 @@ def quantize_blockwise(p, block_size, bits=8):
29
29
  # Calc Stats
30
30
  min_vals, max_vals = torch.aminmax(val_blocks, dim=1, keepdim=True)
31
31
 
32
- # Scale calculation
33
- max_int = (1 << bits) - 1
34
- scales = (max_vals - min_vals).div_(float(max_int))
32
+ # Scale calculation for signed ints
33
+ q_min = -(1 << (bits - 1))
34
+ q_max = (1 << (bits - 1)) - 1
35
+ scales = (max_vals - min_vals).div_(float(q_max - q_min))
35
36
  scales.masked_fill_(scales == 0, 1.0)
36
37
 
37
- # Quantize: (val - min) / scale
38
- quantized = val_blocks.sub_(min_vals).div_(scales).round_().clamp_(0, max_int).to(torch.uint8)
38
+ # Quantize: (val - min) / scale + q_min
39
+ quantized = val_blocks.sub_(min_vals).div_(scales).add_(q_min).round_().clamp_(q_min, q_max).to(torch.int8)
39
40
 
40
41
  return quantized, scales.squeeze(1), min_vals.squeeze(1)
41
42
 
@@ -64,9 +65,10 @@ def _init_anchor(p, state, group):
64
65
  q_blocks, scales, mins = quantize_blockwise(p, block_size=32, bits=4)
65
66
  q_flat = q_blocks.view(-1)
66
67
  # Vectorized packing: High bits | Low bits
67
- packed = (q_flat[0::2] << 4) | q_flat[1::2]
68
+ # Masking with 0x0F prevents two's complement sign extension from overwriting bits
69
+ packed = ((q_flat[0::2] & 0x0F) << 4) | (q_flat[1::2] & 0x0F)
68
70
 
69
- state['anchor_data'] = packed
71
+ state['anchor_data'] = packed.to(torch.int8)
70
72
  state['anchor_scale'] = scales.to(p.dtype)
71
73
  state['anchor_min'] = mins.to(p.dtype)
72
74
 
@@ -88,24 +90,29 @@ def dequantize_anchor(p, state, group, dtype):
88
90
  orig_shape = p.shape
89
91
  orig_numel = p.numel()
90
92
 
91
- if mode == 'int4' and anchor_data.dtype == torch.uint8:
93
+ if mode == 'int4' and anchor_data.dtype == torch.int8:
92
94
  block_size = 32
93
- unpacked = torch.empty(anchor_data.numel() * 2, dtype=torch.uint8, device=anchor_data.device)
95
+ unpacked = torch.empty(anchor_data.numel() * 2, dtype=torch.int8, device=anchor_data.device)
96
+
97
+ # Unpack utilizing standard PyTorch arithmetic shift (sign extends natively)
94
98
  unpacked[0::2] = anchor_data >> 4
95
- unpacked[1::2] = anchor_data & 0x0F
99
+ unpacked[1::2] = (anchor_data << 4) >> 4
100
+
96
101
  quantized_blocks = unpacked.view(-1, block_size)
102
+ q_min = -8
97
103
 
98
- elif mode == 'int8' and anchor_data.dtype == torch.uint8:
104
+ elif mode == 'int8' and anchor_data.dtype == torch.int8:
99
105
  block_size = 128
100
106
  quantized_blocks = anchor_data
107
+ q_min = -128
101
108
 
102
109
  else:
103
110
  # Unrecognised mode/dtype combination
104
111
  return anchor_data.to(dtype)
105
112
 
106
- # Core Dequantization: (q * scale) + min
113
+ # Core Dequantization: (q - q_min) * scale + min
107
114
  anchor_blocks = (
108
- quantized_blocks.float() * scales.float().unsqueeze(1)
115
+ (quantized_blocks.float() - q_min) * scales.float().unsqueeze(1)
109
116
  + mins.float().unsqueeze(1)
110
117
  )
111
118
 
@@ -236,74 +236,6 @@ def copy_stochastic_(target: Tensor, source: Tensor, inplace: bool = False):
236
236
  _copy_stochastic_core_(target, source, random_int_tensor, inplace)
237
237
  del random_int_tensor
238
238
 
239
- def _get_random_int_for_fp8_sr(source: torch.Tensor) -> torch.Tensor:
240
- """
241
- Generates a random int32 tensor for FP8 stochastic rounding.
242
- This function is not torch.compile-path friendly due to its use of torch.Generator.
243
- """
244
- device = source.device
245
-
246
- if device not in _generators:
247
- set_seed(device)
248
-
249
- # TODO: this is a workaround until torch compile error
250
- # NotImplementedError: UserDefinedObjectVariable(generator) is fixed
251
- generator = _generators[device]
252
-
253
- # FP8 e4m3fn always preserves exactly 3 mantissa bits from FP32 (23 bits).
254
- # We need uniform noise in [0, 2^20 - 1]
255
- return torch.randint(
256
- size=source.shape,
257
- device=source.device,
258
- dtype=torch.int32,
259
- low=0,
260
- high=1048576, # 1 << 20
261
- generator=generator,
262
- )
263
-
264
-
265
- def _copy_fp8_stochastic_core_(
266
- target: torch.Tensor,
267
- source: torch.Tensor,
268
- scale: torch.Tensor,
269
- random_int_tensor: torch.Tensor
270
- ):
271
- """
272
- Core logic for FP8 (float8_e4m3fn) stochastic rounding using a pre-computed
273
- random integer tensor.
274
- """
275
- # Scale the source to FP32
276
- buffer = (source * scale).to(torch.float32)
277
-
278
- # Extract magnitude and sign
279
- sign_x = torch.sign(buffer)
280
- buffer.abs_()
281
-
282
- # Create and apply the magic offset
283
- offset = (buffer < 0.015625).to(torch.float32).mul_(0.015625)
284
- buffer.add_(offset)
285
-
286
- # Apply Stochastic Rounding
287
- buffer_int = buffer.view(torch.int32)
288
- buffer_int.add_(random_int_tensor)
289
- buffer_int.bitwise_and_(-1048576)
290
-
291
- # Remove offset and reapply sign
292
- buffer = buffer_int.view(torch.float32)
293
- buffer.sub_(offset)
294
- buffer.mul_(sign_x)
295
-
296
- target.copy_(buffer.to(torch.float8_e4m3fn))
297
-
298
-
299
- def copy_fp8_stochastic_(target: torch.Tensor, source: torch.Tensor, scale: torch.Tensor):
300
- """
301
- Stochastic rounding implementation for FP8 e4m3fn states.
302
- """
303
- random_int_tensor = _get_random_int_for_fp8_sr(source)
304
- _copy_fp8_stochastic_core_(target, source, scale, random_int_tensor)
305
- del random_int_tensor
306
-
307
239
 
308
240
  def _get_random_int_for_8bit_sr(source: torch.Tensor, numel: int | None = None) -> torch.Tensor:
309
241
  """
@@ -29,13 +29,13 @@ def scale_update(
29
29
 
30
30
  # DoRA Magnitude Scales (1D) or 1D Bias/Norm layers
31
31
  if p.ndim < 2 or is_dora_scale:
32
- return rms_normalization(update, dim=None, lr=lr)
32
+ return max_abs_normalization(update, dim=None, lr=lr)
33
33
 
34
34
  # OFT Block Parameters: shape (k, C(b,2))
35
35
  # Normalise by max per-block row norm so that
36
36
  # ‖ΔR_block‖_spec = max_i ‖ΔRᵢ‖_spec ≤ 2 · max_i ‖Δθᵢ‖₂ ≤ target_scale · lr
37
37
  if is_oft:
38
- return max_row_norm_normalization(update, lr)
38
+ return max_row_norm_oft_normalization(p, update, lr)
39
39
 
40
40
  # LoRA Factors or Full Finetuning weights
41
41
  # Scales update to maintain consistent spectral norm across different layer sizes and ranks.
@@ -124,9 +124,7 @@ def init_spectral_norm(state: dict, p: torch.Tensor):
124
124
  @torch.no_grad()
125
125
  def l2_normalization(update: torch.Tensor, dim: int | None, lr: float) -> torch.Tensor:
126
126
  """Performs L2 normalization on the update tensor."""
127
- n = update.numel() if dim is None else update.shape[dim]
128
- norm_eps = 1 / math.sqrt(n)
129
- norm = torch.linalg.vector_norm(update, ord=2, dim=dim, keepdim=True).clamp_min(norm_eps)
127
+ norm = torch.linalg.vector_norm(update, ord=2, dim=dim, keepdim=True).clamp_min(1e-12)
130
128
  return update.mul_(lr / norm)
131
129
 
132
130
 
@@ -134,16 +132,25 @@ def l2_normalization(update: torch.Tensor, dim: int | None, lr: float) -> torch.
134
132
  def rms_normalization(update: torch.Tensor, dim: int | None, lr: float) -> torch.Tensor:
135
133
  """Performs Root Mean Square normalization on the update tensor."""
136
134
  n = update.numel() if dim is None else update.shape[dim]
137
- norm_eps = 1 / math.sqrt(n)
138
- norm = torch.linalg.vector_norm(update, ord=2, dim=dim, keepdim=True).clamp_min(norm_eps)
135
+ norm = torch.linalg.vector_norm(update, ord=2, dim=dim, keepdim=True).clamp_min(1e-12)
139
136
  scale_n = math.sqrt(n)
140
137
  return update.mul_(lr * scale_n / norm)
141
138
 
142
139
  @torch.no_grad()
143
- def max_row_norm_normalization(
140
+ def max_abs_normalization(update: torch.Tensor, dim: int | None, lr: float) -> torch.Tensor:
141
+ """
142
+ Performs L-infinity (Max Absolute) normalization.
143
+ Strictly bounds the maximum update of any single element to 'lr'.
144
+ """
145
+ # ord=float('inf') computes the maximum absolute value
146
+ norm = torch.linalg.vector_norm(update, ord=float('inf'), dim=dim, keepdim=True).clamp_min(1e-12)
147
+ return update.mul_(lr / norm)
148
+
149
+ @torch.no_grad()
150
+ def max_row_norm_oft_normalization(
151
+ p: torch.Tensor,
144
152
  update: torch.Tensor,
145
153
  lr: float,
146
- target_scale: float = 0.5,
147
154
  ) -> torch.Tensor:
148
155
  """
149
156
  Normalizes OFT parameter updates by the maximum per-block (row) L2 norm.
@@ -154,31 +161,43 @@ def max_row_norm_normalization(
154
161
 
155
162
  ‖ΔR_block‖_spec = max_i ‖ΔRᵢ‖_spec ≤ 2 · max_i ‖Δθᵢ‖₂
156
163
 
157
- Unlike spectral normalization of the full (k × C(b,2)) parameter matrix,
158
- this guarantee is exact for all update distributions — including worst-case
159
- concentrated updates where all energy sits in a single block.
160
-
161
164
  Result: Var[Δyⱼ] ≤ (target_scale · lr)² = O(1) for every block configuration.
165
+ """
166
+ # keeps the effective rotation step ‖ΔR_block‖_spec ≤ lr.
167
+ target_scale = 0.5
162
168
 
163
- Args:
164
- update: OFT parameter update, shape (k, C(b,2)).
165
- lr: Learning rate.
166
- target_scale: Desired bound on max_i ‖Δθᵢ‖₂ / lr. Default 0.5 keeps
167
- the effective rotation step ‖ΔR_block‖_spec ≤ lr.
169
+ # Row norms: shape (k,) - one per block
170
+ row_norms = torch.linalg.vector_norm(update, ord=2, dim=-1)
171
+ # Stability floor: equivalent to a single-element vector norm lower bound
172
+ norm_lb = 1.0 / math.sqrt(update.shape[1])
173
+ max_norm = row_norms.max().clamp_min(norm_lb)
168
174
 
169
- Returns:
170
- Scaled update tensor (in-place).
175
+ # Get the magnitude correction factor
176
+ cayley_correction = get_oft_magnitude_correction(p)
177
+
178
+ return update.mul_(lr * cayley_correction * target_scale / max_norm)
179
+
180
+ @torch.no_grad()
181
+ def get_oft_magnitude_correction(p: torch.Tensor) -> torch.Tensor:
182
+ """
183
+ Approximates the magnitude correction of exact Riemannian preconditioning (M @ G @ M).
184
+ Neutralizes the derivative shrinkage of the Cayley transform using a scalar multiplier.
171
185
  """
172
- # Row norms: shape (k,) — one per block
173
- row_norms = torch.linalg.vector_norm(update, ord=2, dim=1)
174
- max_norm = row_norms.max()
186
+ n_el = p.shape[-1]
187
+ b = (1 + math.sqrt(1 + 8 * n_el)) / 2
175
188
 
176
- # Stability floor: equivalent to a single-element vector norm lower bound
177
- norm_eps = 1.0 / math.sqrt(update.shape[1])
178
- max_norm = max_norm.clamp_min(norm_eps)
189
+ # Calculate the squared L2 norm for each block independently.
190
+ p_norm_sq = torch.linalg.vector_norm(p, ord=2, dim=-1).square_()
191
+
192
+ # The expected shrinkage of the Cayley derivative is roughly (1 + lambda^2)^-1,
193
+ # where lambda^2 is the average eigenvalue of -Q^2.
194
+ # Since Tr(-Q^2) = 2 * ||p||_2^2, the average eigenvalue is 2 * ||p||_2^2 / b.
195
+ cayley_correction = 1.0 + (2.0 * p_norm_sq / b)
179
196
 
180
- return update.mul_(lr * target_scale / max_norm)
197
+ # Reshape correction to broadcast against the update tensor (shape (k, 1))
198
+ cayley_correction = cayley_correction.unsqueeze(-1)
181
199
 
200
+ return cayley_correction
182
201
 
183
202
  @torch.no_grad()
184
203
  def spectral_normalization(
@@ -3,7 +3,6 @@ import torch.nn.functional as F
3
3
 
4
4
  from .param_update import (
5
5
  copy_stochastic_, _copy_stochastic_core_,
6
- copy_fp8_stochastic_, _copy_fp8_stochastic_core_,
7
6
  copy_int8_blockwise_stochastic_, _copy_int8_blockwise_stochastic_core_,
8
7
  copy_int8_sym_blockwise_stochastic_, _copy_int8_sym_blockwise_stochastic_core_,
9
8
  )
@@ -22,17 +21,12 @@ def init_state_tensor(state: dict, key: str, shape: tuple, state_precision: str,
22
21
  store_dtype = torch.bfloat16
23
22
  elif state_precision == 'fp16':
24
23
  store_dtype = torch.float16
25
- elif state_precision in ['fp8', 'fp8_sr']:
26
- store_dtype = torch.float8_e4m3fn
27
24
  elif state_precision == 'int8_sr':
28
25
  store_dtype = torch.uint8 if non_neg else torch.int8
29
26
  else: # 'auto'
30
27
  store_dtype = default_dtype
31
28
 
32
- if store_dtype == getattr(torch, 'float8_e4m3fn', None):
33
- state[key] = torch.zeros(shape, device=device, dtype=store_dtype)
34
- state[f"{key}_scale"] = torch.tensor(1.0, device=device, dtype=torch.float32)
35
- elif store_dtype in (torch.uint8, torch.int8):
29
+ if store_dtype in (torch.uint8, torch.int8):
36
30
  numel = 1
37
31
  for s in shape:
38
32
  numel *= s
@@ -48,10 +42,7 @@ def get_state(state: dict, key: str, state_precision: str) -> torch.Tensor:
48
42
  Retrieves and dequantizes the state tensor to float32.
49
43
  """
50
44
  tensor = state[key]
51
- if state_precision in ['fp8', 'fp8_sr']:
52
- scale = state[f"{key}_scale"]
53
- return tensor.float() / scale
54
- elif state_precision == 'int8_sr':
45
+ if state_precision == 'int8_sr':
55
46
  scales = state[f"{key}_scale"] # (n_blocks,) fp32
56
47
  blocks, orig_shape, orig_numel = _prepare_int8_blocks(state[key], _int8_sr_BLOCK_SIZE)
57
48
 
@@ -123,19 +114,6 @@ def set_state(state: dict, key: str, value: torch.Tensor, state_precision: str,
123
114
  if state[key] is not value:
124
115
  state[key].copy_(value)
125
116
 
126
- elif state_precision == 'fp8_sr':
127
- amax = value.abs().max().clamp_min(1e-12)
128
- # Calculate amax
129
- scale = 448.0 / amax
130
-
131
- state[f"{key}_scale"].copy_(scale)
132
-
133
- # Quantize with bitwise Stochastic Rounding
134
- if random_int_state_tensor is None:
135
- copy_fp8_stochastic_(state[key], value, scale)
136
- else:
137
- _copy_fp8_stochastic_core_(state[key], value, scale, random_int_state_tensor)
138
-
139
117
  elif state_precision == 'bf16_sr':
140
118
  # Apply stochastic rounding for BF16 states
141
119
  if random_int_state_tensor is None:
@@ -195,7 +173,7 @@ def upcast_grad_for_precision(grad: torch.Tensor, state: dict, state_precision:
195
173
 
196
174
  # Low-precision storage modes benefit from FP32 accumulation to
197
175
  # maintain accuracy before quantizing back down in set_state.
198
- if state_precision in ['fp32', 'bf16_sr', 'fp8_sr', 'int8_sr', 'factored']:
176
+ if state_precision in ['fp32', 'bf16_sr', 'int8_sr', 'factored']:
199
177
  return grad.float()
200
178
 
201
179
  return grad
@@ -216,8 +194,6 @@ def fix_loaded_state_dtype(state: dict, p: torch.Tensor, group: dict) -> None:
216
194
  base_dtype = torch.float32
217
195
  elif actual_precision == 'bf16_sr':
218
196
  base_dtype = torch.bfloat16
219
- elif actual_precision in ['fp8', 'fp8_sr']:
220
- base_dtype = torch.float8_e4m3fn
221
197
  elif actual_precision == 'int8_sr':
222
198
  base_dtype = torch.uint8
223
199
  else:
@@ -245,8 +221,8 @@ def fix_loaded_state_dtype(state: dict, p: torch.Tensor, group: dict) -> None:
245
221
  if val.dtype != p.dtype:
246
222
  state[key] = val.to(p.dtype)
247
223
  elif mode in ['int8', 'int4']:
248
- if val.dtype != torch.uint8:
249
- state[key] = val.to(torch.uint8)
224
+ if val.dtype != torch.int8:
225
+ state[key] = val.to(torch.int8)
250
226
  elif mode == 'float8':
251
227
  if val.dtype != torch.float8_e4m3fn:
252
228
  state[key] = val.to(torch.float8_e4m3fn)
@@ -263,7 +239,7 @@ def fix_loaded_state_dtype(state: dict, p: torch.Tensor, group: dict) -> None:
263
239
  state[key] = val.to(torch.uint8)
264
240
  continue
265
241
 
266
- # Handle Factorized Tensors, FP8 Scales, and blockwise INT8 scale
242
+ # Handle Factorized Tensors, and blockwise INT8 scale
267
243
  if key in fp32_keys or (key.endswith('_scale') and key != 'anchor_scale'):
268
244
  if val.dtype != torch.float32:
269
245
  state[key] = val.to(torch.float32)
@@ -0,0 +1,109 @@
1
+ Metadata-Version: 2.4
2
+ Name: adv_optm
3
+ Version: 2.4.dev24
4
+ Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
+ Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
+ Author: Koratahiu
7
+ Author-email: hiuhonor@gmail.com
8
+ License: Apache 2.0
9
+ Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: Apache Software License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
15
+ Requires-Python: >=3.8
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: torch>=2.1
19
+ Dynamic: author
20
+ Dynamic: author-email
21
+ Dynamic: classifier
22
+ Dynamic: description
23
+ Dynamic: description-content-type
24
+ Dynamic: home-page
25
+ Dynamic: keywords
26
+ Dynamic: license
27
+ Dynamic: license-file
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
31
+
32
+ # Advanced Optimizers (AIO)
33
+
34
+ A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
35
+
36
+ [![PyPI](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
37
+
38
+ ## 🔥 What's New
39
+
40
+ ### In 2.4.x:
41
+
42
+ This update introduces a whole refactor of the library with many new features and changes:
43
+
44
+ - New optimizer state precision option (`state_precision`) with several settings for how the optimizer states are stored: rank-2 factored mode (`factored`), full FP32 (`fp32`), BF16 with Stochastic Rounding (`bf16_sr`), int8/uint8 with Stochastic Rounding (`int8_sr`), and FP16 (`fp16`).
45
+ - Added new powerful optimizer: SinkSGD_adv.
46
+ - Added spectral scaling option to all optimizers, achieving width/rank invariant updates.
47
+ - Added Nesterov momentum (`nesterov`) and its coef (`nesterov_coef`) to all optimizers.
48
+ - Added centered weight decay (`centered_wd`), to pull the weights toward their pre-train state (anchor)
49
+ - anchor precision can be changed to save memory (`centered_wd_mode`): full, float8, int8, int4
50
+ - Added Fisher Weight Decay option for Adam variants (`fisher_wd`).
51
+ - Paper: [FAdam...](https://arxiv.org/abs/2405.12807)
52
+ - Added Factored Second Moment option for Adam variants (`factored_2nd`). This works alongside any `state_precision` setting.
53
+ - Added Geometric Weight Decay for SinkSGD_adv and SignSGD_adv.
54
+ - Added a powerful new mode: variance-normalized momentum (`normed_momentum`), which applies the optimizer's normalization before the momentum (also known as Normalization-then-Momentum, NtM).
55
+ - For: AdamW_adv, SignSGD_adv, SinkSGD_adv.
56
+ - Added Variance/Confidence Preconditioning (`snr_cond`) for SignSGD_adv, SinkSGD_adv.
57
+ - Only works with `normed_momentum`.
58
+ - Technical reports: [AASS](https://koratahiu.github.io/aass/), and [sink-v](https://koratahiu.github.io/sink-v/).
59
+ - Added Adaptive Stochastic Sign with L_inf preconditioning (`stochastic_sign`) for SignSGD_Adv and Lion_adv.
60
+ - Improved CANS (`accelerated_ns`) for Muon variants, by integrating dynamic lower bound.
61
+ - Removed the Simplified_AdEMAMix optimizer and its settings in other optimizers; they are now replaced by Nesterov momentum and its coefficient, which is better and easier to tune.
62
+ - Removed cautious and grams modes, as they were heuristic and not working well.
63
+ - Removed optimizers: Lion_Prodigy_adv, and Simplified_AdEMAMix.
64
+
65
+ ### in 2.1.x
66
+
67
+ - Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
68
+ - More info coming soon.
69
+
70
+ ### in 2.0.x
71
+
72
+ * Implemented torch.compile for all advanced optimizers. Enabled via (compiled_optimizer=True) to fuse and optimize the optimizer step path.
73
+ * Better and improved 1-bit factored mode via (nnmf_factor=True).
74
+ * Various improvements across the optimizers.
75
+
76
+ ### in 1.2.x
77
+ * Added **advanced variants** of [Muon optimizer](https://kellerjordan.github.io/posts/muon/) with **features** and **settings** from recent papers.
78
+
79
+ | Optimizer | Description |
80
+ |---|---|
81
+ | `Muon_adv` | Advanced Muon implementation with CANS, NorMuon, Low-Rank ortho, etc. features. |
82
+ | `AdaMuon_adv` | Advanced AdaMuon implementation, which combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
83
+
84
+ > *Documentation coming soon.*
85
+
86
+ * Implemented [Cautious Weight Decay](https://arxiv.org/abs/2510.12402) for all advanced optimizers.
87
+
88
+ * Improved parameter update and weight decay for **BF16** with **stochastic rounding**. The updates are now accumulated in **float32** and rounded once at the end.
89
+
90
+ * Use fused and in-place operations whenever possible for all advanced optimizers.
91
+
92
+ * **Prodigy variants** are now **50% faster** by [avoiding CUDA syncs](https://github.com/Koratahiu/Advanced_Optimizers/pull/5). Thanks to **@dxqb**!
93
+
94
+ ---
95
+
96
+ ## 📦 Installation
97
+
98
+ ```bash
99
+ pip install adv_optm
100
+ ```
101
+
102
+ ---
103
+
104
+ ## 🧠 Core Innovations
105
+
106
+ This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training.
107
+
108
+ ---
109
+
@@ -0,0 +1,29 @@
1
+ adv_optm/__init__.py,sha256=FszXXnlL8-PPppcRuJ96wKzbflM1jf7vXLop8mK3XnI,356
2
+ adv_optm/optim/AdaMuon_adv.py,sha256=KfpOAOHMshnF2R-Cn00VRraGmAXQrVZFo4ZXMOyMXDc,33368
3
+ adv_optm/optim/AdamW_adv.py,sha256=Gq6FsIpA0RT35zuyk_lEBCngrftgbaeLIiCW0ExN6Qw,22688
4
+ adv_optm/optim/Adopt_adv.py,sha256=f423WZTq120jg4TJoCqckzPTY30TyskbBLBXaPKhdWM,22919
5
+ adv_optm/optim/Lion_adv.py,sha256=pXQRPKNstyc2u3Mq34WDpaLP-TAZJLewFLHXHMaXvbE,12475
6
+ adv_optm/optim/Muon_adv.py,sha256=nN_AdQ7dcPSHJseEdmygj7vVK7No2CpoC6pbcUiHjoE,27553
7
+ adv_optm/optim/Prodigy_adv.py,sha256=hYY1B167J-t-BPkLM_ZL3L-Kc4ud4feDnta3AjiMGoQ,26942
8
+ adv_optm/optim/SignSGD_adv.py,sha256=PXSR5l2Pze0zsz9yj7CPnRt7EzibzOafT8xF4mUlU-g,16489
9
+ adv_optm/optim/SinkSGD_adv.py,sha256=IyPoDikVHHKG6_GrGWQ-bdkqGfg08OJ6KX_FXBdWubE,16632
10
+ adv_optm/optim/__init__.py,sha256=RkpzWpEgAOdQAppTBLc5AilfJq-wn0aabNsQDj_7-4A,452
11
+ adv_optm/util/Kourkoutas.py,sha256=tTo2QbOWPhI29hWQ4ERosoZXRAFXXWDzIMCj1KMFaOE,13325
12
+ adv_optm/util/Muon_AuxAdam.py,sha256=C79uwrwqRplIVdJV1-omkNkfdPwcLrokzHetAvt6fk8,8423
13
+ adv_optm/util/Muon_util.py,sha256=w6f6aeBpuJwNhltXiVlJy8X52hlaGQACdPHFFN7-hJg,15280
14
+ adv_optm/util/OrthoGrad.py,sha256=bn9PifXd85Xb2padHFqGKw-huEpTwmckyt8gz5Zrias,3266
15
+ adv_optm/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ adv_optm/util/centered_decay.py,sha256=5xaxkU-Vf6CDpN6gXv6A6RUylNPXY2vsiSVZe_mCViY,4257
17
+ adv_optm/util/factorization_util.py,sha256=I75o2Cxlp4XKMBBnuRa6nDWAT9aJqNS75zv7-OGO5K8,3974
18
+ adv_optm/util/lion_k.py,sha256=b3oqTHK1spoD_Hi8wxKZ9lEWUkHejovT4piU-wUp_eI,1686
19
+ adv_optm/util/param_update.py,sha256=7-iHC0oQ72zHEfjPza27wgV6vEC77by4eayKItWqNPc,16167
20
+ adv_optm/util/scaled_optm.py,sha256=e6hY90mD_gSrDr_rDglLe6p0S2-YH3tXb71LNgm8B8Y,8925
21
+ adv_optm/util/signed_util.py,sha256=ebQn-O8pNvld3RvsJ_MAdCL4C8pBNS15Uqzr2xHChss,1873
22
+ adv_optm/util/sinkhorn.py,sha256=oFdWY243jBeuBKLBS47I5dE8DsK5OqJNDWlMBKXsxUM,5075
23
+ adv_optm/util/state_util.py,sha256=mpsyjhvtmNUey5xMWPFwhDzDTg3m4XA6Lo4RS2z-WCI,10126
24
+ adv_optm/util/update_util.py,sha256=IAaYuxrJ9kP2P4Z0v02pesRe4mRzP5SvXbLHjLrrdZw,1332
25
+ adv_optm-2.4.dev24.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
26
+ adv_optm-2.4.dev24.dist-info/METADATA,sha256=JvPcViuM3LwAF28oknH2JAkDL3vc_7rfIzwZ-KORotQ,5240
27
+ adv_optm-2.4.dev24.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
28
+ adv_optm-2.4.dev24.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
29
+ adv_optm-2.4.dev24.dist-info/RECORD,,
@@ -1,202 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: adv_optm
3
- Version: 2.4.dev22
4
- Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
- Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
- Author: Koratahiu
7
- Author-email: hiuhonor@gmail.com
8
- License: Apache 2.0
9
- Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
10
- Classifier: Programming Language :: Python :: 3
11
- Classifier: License :: OSI Approved :: Apache Software License
12
- Classifier: Operating System :: OS Independent
13
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
15
- Requires-Python: >=3.8
16
- Description-Content-Type: text/markdown
17
- License-File: LICENSE
18
- Requires-Dist: torch>=2.1
19
- Dynamic: author
20
- Dynamic: author-email
21
- Dynamic: classifier
22
- Dynamic: description
23
- Dynamic: description-content-type
24
- Dynamic: home-page
25
- Dynamic: keywords
26
- Dynamic: license
27
- Dynamic: license-file
28
- Dynamic: requires-dist
29
- Dynamic: requires-python
30
- Dynamic: summary
31
-
32
- # Advanced Optimizers (AIO)
33
-
34
- A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
35
-
36
- [![PyPI](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
37
-
38
- ## 🔥 What's New
39
-
40
- ### in 2.1.x
41
-
42
- - Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
43
- - More info coming soon.
44
-
45
- ### in 2.0.x
46
-
47
- * Implemented torch.compile for all advanced optimizers. Enabled via (compiled_optimizer=True) to fuse and optimize the optimizer step path.
48
- * Better and improved 1-bit factored mode via (nnmf_factor=True).
49
- * Various improvements across the optimizers.
50
-
51
- ### in 1.2.x
52
- * Added **advanced variants** of [Muon optimizer](https://kellerjordan.github.io/posts/muon/) with **features** and **settings** from recent papers.
53
-
54
- | Optimizer | Description |
55
- |---|---|
56
- | `Muon_adv` | Advanced Muon implementation with CANS, NorMuon, Low-Rank ortho, etc. features. |
57
- | `AdaMuon_adv` | Advanced AdaMuon implementation, which combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
58
-
59
- > *Documentation coming soon.*
60
-
61
- * Implemented [Cautious Weight Decay](https://arxiv.org/abs/2510.12402) for all advanced optimizers.
62
-
63
- * Improved parameter update and weight decay for **BF16** with **stochastic rounding**. The updates are now accumulated in **float32** and rounded once at the end.
64
-
65
- * Use fused and in-place operations whenever possible for all advanced optimizers.
66
-
67
- * **Prodigy variants** are now **50% faster** by [avoiding CUDA syncs](https://github.com/Koratahiu/Advanced_Optimizers/pull/5). Thanks to **@dxqb**!
68
-
69
- ---
70
-
71
- ## 📦 Installation
72
-
73
- ```bash
74
- pip install adv_optm
75
- ```
76
-
77
- ---
78
-
79
- ## 🧠 Core Innovations
80
-
81
- This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training, with **1-bit compression for optimizer states**:
82
-
83
- ### **Memory-Efficient Optimization (SMMF-inspired)**
84
- - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
85
- - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
86
- - **Innovation**:
87
- - First moment split into **1-bit sign + absolute value**
88
- - Final storage: **four factored vectors + one 1-bit sign state**
89
- - Preserves Adam-like update quality with drastically reduced memory
90
-
91
- ---
92
-
93
- ## ⚡ Performance Characteristics
94
-
95
- ### Memory Efficiency (SDXL Model – 6.5GB)
96
- | Optimizer | Memory Usage | Description |
97
- |-----------|--------------|-------------|
98
- | `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
99
- | `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
100
-
101
- ### Speed Comparison (SDXL, Batch Size 4)
102
- | Optimizer | Speed | Notes |
103
- |-----------|-------|-------|
104
- | `Adafactor` | ~8.5s/it | Baseline |
105
- | `Adopt_Factored` | ~10s/it | +18% overhead from compression |
106
- | `Adopt_Factored + AdEMAMix` | ~12s/it | +41% overhead (3 factored states) |
107
-
108
- ---
109
-
110
- ## 🧪 Available Optimizers
111
-
112
- ### Standard Optimizers (All support `factored=True/False`)
113
- | Optimizer | Description | Best For |
114
- |-----------|-------------|----------|
115
- | `Adam_Adv` | Advanced Adam implementation | General purpose |
116
- | `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
117
- | `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
118
- | `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
119
- | `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
120
-
121
- ---
122
-
123
- ## ⚙️ Feature Matrix
124
-
125
- | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Lion_Adv |
126
- |---------|----------|-----------|-------------|----------|
127
- | Factored | ✓ | ✓ | ✓ ✓ |
128
- | OrthoGrad | ✓ | ✓ | ✓ | ✓ |
129
- | atan2 | ✓ | ✓ | ✓ |✗ |
130
- | Stochastic Rounding | ✓ | ✓ | ✓ |✓ |
131
- | Fused Backward Pass | ✓ | ✓ | ✓ | ✓ |
132
- | **Kourkoutas-β** | ✓ | ✓ | ✓ | ✗ |
133
-
134
- ---
135
-
136
- ## 🛠️ Comprehensive Feature Guide
137
-
138
- ### A. Universal Safe Features
139
- *These features work with all optimizers and are generally safe to enable.*
140
-
141
- | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
142
- |--------|-------------|-------------------|--------------------|-------------------|--------------|
143
- | **Fused Back Pass** | Fuses backward pass; gradients used immediately and memory freed on-the-fly | Memory-constrained environments | Reduces peak memory | Memory optimization | All optimizers |
144
- | **Stochastic Rounding** | Replaces nearest rounding with stochastic rounding to preserve small gradient updates in BF16 | BF16 training | Minimal overhead (<5%) | [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192) | All optimizers |
145
- | **OrthoGrad** | Removes gradient component parallel to weights to reduce overfitting | Full fine-tuning without weight decay | +33% time overhead (BS=4); less at larger BS | [Grokking at Edge](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | All optimizers |
146
- | **Factored** | Memory-efficient optimization via rank-1 1-bit factorization of optimizer states | Large models / memory-limited hardware | Adds compression overhead | [SMMF](https://arxiv.org/abs/2412.08894) | All optimizers |
147
-
148
- ### B. Individual Features
149
-
150
- | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
151
- |--------|-------------|-------------------|--------------------|-------------------|--------------|
152
- | **atan2** | Robust epsilon replacement with built-in gradient clipping | Use for stable bounded updates (or for Adopt as it needs that) | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
153
- | **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β]() | Adam/Adopt/Prodigy |
154
-
155
- ---
156
-
157
- ## 🔍 Feature Deep Dives
158
-
159
- ### atan2
160
-
161
- - Replaces `eps` in Adam-family optimizers with a **scale-invariant**, bounded update rule.
162
- - Automatically clips updates to **[-2, 2]**, preventing destabilizing jumps.
163
- - **Highly recommended** for `Adopt_Adv`, which is prone to instability without clipping.
164
-
165
- > 📚 **Reference**:
166
- > - Paper: https://arxiv.org/abs/2407.05872
167
- > - Code: https://github.com/lucidrains/adam-atan2-pytorch
168
-
169
- ---
170
-
171
- ### **Kourkoutas-β**
172
-
173
- **Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, `Prodigy_Adv`.
174
-
175
- Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
176
-
177
- - **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
178
- - **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
179
-
180
- This is especially effective for **noisy training, small batch sizes, and high learning rates**, where gradient norms shift abruptly due to noise or aggressive LR schedules.
181
-
182
- #### Pros/Cons
183
-
184
- | **Category** | **Details** |
185
- |--------------|-------------|
186
- | ✅ **Pros** | • **Layer-wise adaptation** blends benefits of high β₂ (strong smoothing) and low β₂ (fast reaction).<br>• **Robust to sudden loss landscape shifts**, reacts quickly during gradient bursts, smooths during calm phases.<br>• **High tolerance to aggressive learning rates**. |
187
- | ⚠️ **Cons** | • **Potentially unstable at the start of training** due to unreliable early gradient norms; mitigated by using `K-β Warmup Steps`. |
188
-
189
- > 💡 **Best Practice**: Set `K_warmup_steps` equal to your standard LR warmup steps. During warmup, the optimizer uses the static `beta2`; adaptation begins only after warmup ends.
190
-
191
- > 📚 **Reference**:
192
- > - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
193
- > - Code: [kbeta](https://github.com/sck-at-ucy/kbeta)
194
-
195
- ---
196
-
197
- ## 📚 References
198
-
199
- 1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
200
- 2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
201
- 6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
202
- 7. [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/abs/2407.05872)
@@ -1,29 +0,0 @@
1
- adv_optm/__init__.py,sha256=_JP348MJrOkexPp6UrvAn4CYPZoZoSXHP5UZiNrPoKA,356
2
- adv_optm/optim/AdaMuon_adv.py,sha256=2ATHZhO7sfNE50gIc0ydb1nwY9ZbQM83d-Q_9KaKsqY,33572
3
- adv_optm/optim/AdamW_adv.py,sha256=_Nb_hw01ALW3se869kYiOJEKOftlQLW04Dj7Wver6eo,22749
4
- adv_optm/optim/Adopt_adv.py,sha256=WpuuWuO1c8mOVKAaTAbsKUMfQ2IaZWKKRd9C-TOi4WY,22988
5
- adv_optm/optim/Lion_adv.py,sha256=owq4RUKjw-xfn5r_UOekTr5EGqpx8zCp369KkK9vll8,12596
6
- adv_optm/optim/Muon_adv.py,sha256=2EmsS74_UD4A4yidv1q5EOvzpmqNDVKUWyYXeLmipo0,27877
7
- adv_optm/optim/Prodigy_adv.py,sha256=292F0ckr0MYEGex34z3F2xaWAcPPGP-KdyZGEteP17M,26991
8
- adv_optm/optim/SignSGD_adv.py,sha256=GF3g-EuVvcGBrg00UhIstoJ27vDxgxWifGUR38Q6iKw,16012
9
- adv_optm/optim/SinkSGD_adv.py,sha256=BPJ5xMweqxFnf6pWePa7iBbY73oXE3I_67S-3nBProg,16118
10
- adv_optm/optim/__init__.py,sha256=RkpzWpEgAOdQAppTBLc5AilfJq-wn0aabNsQDj_7-4A,452
11
- adv_optm/util/Kourkoutas.py,sha256=tTo2QbOWPhI29hWQ4ERosoZXRAFXXWDzIMCj1KMFaOE,13325
12
- adv_optm/util/Muon_AuxAdam.py,sha256=le72chD-VhQoDYN5BwungO_3G_7nZHwR89X9WauISsY,8313
13
- adv_optm/util/Muon_util.py,sha256=w6f6aeBpuJwNhltXiVlJy8X52hlaGQACdPHFFN7-hJg,15280
14
- adv_optm/util/OrthoGrad.py,sha256=TXZENLTa66hvtdIgyldyRgZo25GgQQx5MCZTK1KvE5g,780
15
- adv_optm/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- adv_optm/util/centered_decay.py,sha256=F6IGOjrEaVS5HAvO67u1gcDcu8TLQINFHL5s9RHoIb8,3869
17
- adv_optm/util/factorization_util.py,sha256=I75o2Cxlp4XKMBBnuRa6nDWAT9aJqNS75zv7-OGO5K8,3974
18
- adv_optm/util/lion_k.py,sha256=b3oqTHK1spoD_Hi8wxKZ9lEWUkHejovT4piU-wUp_eI,1686
19
- adv_optm/util/param_update.py,sha256=cUubaDgurnu-j9nWlyPDmTx48fJQV727xR48VVL2vGk,18281
20
- adv_optm/util/scaled_optm.py,sha256=a10D2oZJjfhjx2I6FP71oZ0LLcZ0Sp-PYbqTVk1W6QE,8014
21
- adv_optm/util/signed_util.py,sha256=ebQn-O8pNvld3RvsJ_MAdCL4C8pBNS15Uqzr2xHChss,1873
22
- adv_optm/util/sinkhorn.py,sha256=oFdWY243jBeuBKLBS47I5dE8DsK5OqJNDWlMBKXsxUM,5075
23
- adv_optm/util/state_util.py,sha256=K3cF_HTDqU__lC4PYDSd16K0ITzcKPbsSXyZPFUyHSU,11199
24
- adv_optm/util/update_util.py,sha256=IAaYuxrJ9kP2P4Z0v02pesRe4mRzP5SvXbLHjLrrdZw,1332
25
- adv_optm-2.4.dev22.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
26
- adv_optm-2.4.dev22.dist-info/METADATA,sha256=S3Vvj3dsUSOF0km5FGq33CYu-2hGxdltQYMdbur1eZo,9725
27
- adv_optm-2.4.dev22.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
28
- adv_optm-2.4.dev22.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
29
- adv_optm-2.4.dev22.dist-info/RECORD,,