adv-optm 1.2.dev7__py3-none-any.whl → 1.2.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
     "AdaMuon_adv",
 ]
 
-__version__ = "1.2.dev7"
+__version__ = "1.2.dev9"
adv_optm/optim/AdaMuon_adv.py CHANGED
@@ -3,7 +3,6 @@ from typing import Optional, Callable
 
 from .AdamW_adv import AdamW_adv
 from ..util.MuonAdam_helper import MuonAdamHelper
-from ..util.Kourkoutas import KourkoutasHelper
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -64,22 +63,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
             matrices for muon NewtonSchulz (default: False).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
             matrices to apply low-rank compression (default: True).
+        low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+            projects the update to a lower rank before orthogonalization.
+            (default: False)
+        ortho_rank (int): The rank for low-rank orthogonalization.
+            (default: 128)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
-        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
-            If `False`, the optimizer behaves as standard AdamW. (default: False)
-        beta2_min (float): The minimum value for dynamic β₂, used during periods of
-            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
-            (default: 0.88)
-        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
-            the pooled gradient norms. Corresponds to `α` in the paper.
-            (default: 0.93)
-        tiny_spike (float): A small constant added to the denominator of the
-            "sunspike" ratio calculation to prevent division by zero. Corresponds
-            to `ε_spike` in the paper. (default: 1e-9)
-        k_warmup_steps (int): The number of initial steps during which β₂ is held
-            at a fixed beta2 value before the
-            dynamic logic activates. (default: 0)
         MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
             Parameters designated by `layer_key_fn` will be optimized with
             AdamW_adv instead of Muon. (default: False)
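
For orientation, the option pair above implements a sketch-and-project variant of Muon's orthogonalization: the update matrix is multiplied by a random Gaussian sketch, a QR factorization yields an orthonormal basis Q for the sketched range, Newton-Schulz runs on the much smaller projected matrix QᵀM, and the result is lifted back with Q. A minimal self-contained sketch of the idea (standalone code; `orthogonalize` below is an SVD stand-in for the package's `_newton_schulz_iteration`):

```python
import torch

def orthogonalize(M: torch.Tensor) -> torch.Tensor:
    # Stand-in for _newton_schulz_iteration: maps M to its nearest
    # semi-orthogonal matrix via the polar factor U @ Vh.
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    return U @ Vh

def low_rank_ortho(M: torch.Tensor, rank: int = 128) -> torch.Tensor:
    """Sketch-and-project orthogonalization, mirroring the diff's logic."""
    r = min(rank, M.shape[0], M.shape[1])
    if r <= 0:                          # fallback: full orthogonalization
        return orthogonalize(M)
    G = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
    Q, _ = torch.linalg.qr(M @ G)       # orthonormal basis for the sketched range
    return Q @ orthogonalize(Q.T @ M)   # orthogonalize in the r-dim subspace, lift back

M = torch.randn(512, 256)
print(low_rank_ortho(M, rank=64).shape)  # torch.Size([512, 256])
```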
@@ -110,15 +100,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
         alpha_grad: float = 100.0,
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
+        # Low-rank Muon
+        low_rank_ortho: bool = False,
+        ortho_rank: int = 128,
         nnmf_factor: bool = False,
-        # K-b parameters
-        kourkoutas_beta: bool = False,
-        beta2_min: float = 0.9,
-        ema_alpha: float = 0.95,
-        tiny_spike: float = 1e-9,
-        k_warmup_steps: int = 0,
-        k_logging: int = 0,
-        layer_key_kb_fn: Optional[Callable] = None,
         # hybrid optimizer mode
         MuonWithAuxAdam: bool = False,
         layer_key_fn: Optional[Callable] = None,
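
Given the constructor signature above, enabling the new path is a one-flag change. A hypothetical usage sketch (the linear layer and the `lr` value are placeholders; other hyperparameters keep their defaults):

```python
import torch
from adv_optm import AdaMuon_adv

model = torch.nn.Linear(1024, 1024)  # placeholder model

opt = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,
    low_rank_ortho=True,  # sketch-and-project instead of full Newton-Schulz
    ortho_rank=64,        # subspace rank; clamped to the matrix dimensions
)
```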
@@ -142,14 +127,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
             "nesterov":nesterov, "use_atan2":use_atan2,
-            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-            "_kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            # Low-rank Ortho
+            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
         }
         self.stochastic_rounding = stochastic_rounding
-        self._kourkoutas_beta = kourkoutas_beta
-        self._kourkoutas_helper = None
-        self.layer_key_kb_fn = layer_key_kb_fn
         self.MuonWithAuxAdam = MuonWithAuxAdam
         self.helper = None
         self.aux_adam = None
@@ -182,14 +164,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
             for key, value in defaults_to_use.items():
                 new_group.setdefault(key, value)
-            if '_kourkoutas_beta' not in new_group:
-                if optim_type == 'adam':
-                    new_group['_kourkoutas_beta'] = False
-                else:
-                    new_group['_kourkoutas_beta'] = muon_defaults['_kourkoutas_beta']
             final_param_groups.append(new_group)
 
-        super().__init__(final_param_groups, {})
+        super().__init__(final_param_groups, muon_defaults)
 
         # Now that self is initialized, create the helper
         self.helper = MuonAdamHelper(self, layer_key_fn)
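
The switch from an empty dict to `muon_defaults` matters because `torch.optim.Optimizer` backfills every param group from the defaults dict, so later lookups like `group['low_rank_ortho']` cannot raise `KeyError`. A toy demonstration of that base-class behavior (the `Toy` class is illustrative, not package code):

```python
import torch

class Toy(torch.optim.Optimizer):
    def __init__(self, params, defaults):
        super().__init__(params, defaults)
    def step(self, closure=None):  # no-op; only the group handling matters here
        pass

p = torch.nn.Parameter(torch.zeros(2))
opt = Toy([{"params": [p]}], {"lr": 0.1, "low_rank_ortho": False})
# The base class copied the missing defaults into the group:
print(opt.param_groups[0]["low_rank_ortho"])  # False
```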
@@ -219,9 +196,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
-            self._kourkoutas_helper = KourkoutasHelper(self)
-
         if self.MuonWithAuxAdam:
             optim_type = self.helper.get_optimizer_type(p)
             if optim_type == 'adam':
@@ -277,7 +251,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
         # Retrieve hyperparameters
         beta1, beta2 = group['betas']
-        current_step = state['step']
         nesterov = group['nesterov']
         Simplified_AdEMAMix = group['Simplified_AdEMAMix']
         alpha_grad = group['alpha_grad']
@@ -303,20 +276,37 @@ class AdaMuon_adv(torch.optim.Optimizer):
             signed_m_buf = torch.sign(mt_buf)
             del grad_reshaped
 
-            update = _newton_schulz_iteration(
-                signed_m_buf,
-                steps=group['ns_steps'],
-                eps=group['ns_eps'],
-                coeffs=group['ns_coeffs'],
-            )
-
-            if group['_kourkoutas_beta']:
-                # Call prepare_step() once at the beginning of the step for all params
-                self._kourkoutas_helper.maybe_prepare_step(current_step)
-                # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
-                self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update.view(p.shape))
-                # Get the dynamic beta2 calculated in prepare_step()
-                beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
+            # Orthogonalization step
+            if group['low_rank_ortho']:
+                # Low-Rank Orthogonalization on the reconstructed matrix
+                M = signed_m_buf
+                r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+                if r > 0:
+                    G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                    MG = M @ G_sketch
+                    if MG.dtype != torch.float32:
+                        MG_dtype = M.dtype
+                        Q, _ = torch.linalg.qr(MG.float())
+                        Q = Q.to(MG_dtype)
+                    else:
+                        Q, _ = torch.linalg.qr(MG)
+                    projected_M = Q.T @ M
+                    ortho_projected_M = _newton_schulz_iteration(
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+                    update = Q @ ortho_projected_M
+                else:  # Fallback for invalid rank
+                    update = _newton_schulz_iteration(
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+            else:
+                # Original full Newton-Schulz
+                update = _newton_schulz_iteration(
+                    signed_m_buf,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
 
             # Reconstruct second momentum from previous step's factors
             vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
@@ -337,7 +327,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
             # RMS-aligned rescaling
             rms_target = group['rms_target']
             num_elements = update.numel()
-            scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+            # Add eps to prevent division by zero
+            scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])
 
             update.mul_(scaling_factor)
             update = update.view(p.shape).mul_(group['lr'])
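
Since RMS(update) = ‖update‖ / √N, multiplying by `rms_target * √N / (‖update‖ + eps)` drives the update's RMS to approximately `rms_target`; the added `eps` only guards the degenerate all-zero update. A quick numeric check (arbitrary values):

```python
import torch

update = torch.randn(4096)
rms_target, eps = 0.2, 1e-8

scale = rms_target * (update.numel() ** 0.5) / (update.norm() + eps)
update = update * scale
print(update.pow(2).mean().sqrt())  # ~0.2
```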
@@ -381,25 +372,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if len(p.shape) > 2:
             signed_m_buf = signed_m_buf.view(p.shape[0], -1)
 
-        # NewtonSchulz
-        update = _newton_schulz_iteration(
-            signed_m_buf,
-            steps=group['ns_steps'],
-            eps=group['ns_eps'],
-            coeffs=group['ns_coeffs'],
-        )
+        # Orthogonalization step
+        if group['low_rank_ortho']:
+            # Low-Rank Orthogonalization on the reconstructed matrix
+            M = signed_m_buf
+            r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+            if r > 0:
+                G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                MG = M @ G_sketch
+                if MG.dtype != torch.float32:
+                    MG_dtype = M.dtype
+                    Q, _ = torch.linalg.qr(MG.float())
+                    Q = Q.to(MG_dtype)
+                else:
+                    Q, _ = torch.linalg.qr(MG)
+                projected_M = Q.T @ M
+                ortho_projected_M = _newton_schulz_iteration(
+                    projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                )
+                update = Q @ ortho_projected_M
+            else:  # Fallback for invalid rank
+                update = _newton_schulz_iteration(
+                    signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                )
+        else:
+            # Original full Newton-Schulz
+            update = _newton_schulz_iteration(
+                signed_m_buf,
+                steps=group['ns_steps'],
+                eps=group['ns_eps'],
+                coeffs=group['ns_coeffs'],
+            )
 
         if len(p.shape) > 2 or state['reshaped_1d_muon']:
             update = update.view(p.shape)
 
-        if group['_kourkoutas_beta']:
-            # Call prepare_step() once at the beginning of the step for all params
-            self._kourkoutas_helper.maybe_prepare_step(current_step)
-            # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
-            self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update)
-            # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
-
         vt_buf = state['second_momentum_buffer']
         vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
@@ -416,7 +423,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
         # RMS-aligned rescaling
         rms_target = group['rms_target']
         num_elements = update.numel()
-        scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+        # Add eps to prevent division by zero
+        scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])
 
         update.mul_(scaling_factor)
         del num_elements, scaling_factor
adv_optm/optim/Muon_adv.py CHANGED
@@ -178,7 +178,7 @@ class Muon_adv(torch.optim.Optimizer):
 
         final_param_groups.append(new_group)
 
-        super().__init__(final_param_groups, {})
+        super().__init__(final_param_groups, muon_defaults)
 
         # Now that self is initialized, create the helper
         self.helper = MuonAdamHelper(self, layer_key_fn)
@@ -292,10 +292,10 @@ class Muon_adv(torch.optim.Optimizer):
         del grad_reshaped
 
         # Orthogonalization step
-        if group['low_rank_muon']:
+        if group['low_rank_ortho']:
             # Low-Rank Orthogonalization on the reconstructed matrix
             M = update
-            r = min(group['low_rank_rank'], M.shape[0], M.shape[1])
+            r = min(group['ortho_rank'], M.shape[0], M.shape[1])
             if r > 0:
                 G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
                 MG = M @ G_sketch
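
Note the rename means a param group (or resumed config) that still sets the old `low_rank_muon`/`low_rank_rank` keys would silently fall back to the new defaults. A hypothetical downstream guard, not part of the package:

```python
def migrate_group_keys(group: dict) -> dict:
    # Map the pre-1.2.dev9 key names onto the renamed ones.
    if 'low_rank_muon' in group:
        group.setdefault('low_rank_ortho', group.pop('low_rank_muon'))
    if 'low_rank_rank' in group:
        group.setdefault('ortho_rank', group.pop('low_rank_rank'))
    return group
```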
adv_optm/util/Kourkoutas.py CHANGED
@@ -94,7 +94,7 @@ class KourkoutasHelper:
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']
 
-            if not group.get('kourkoutas_beta', False):
+            if not group.get('kourkoutas_beta', False) and not group.get('_kourkoutas_beta', False):
                 continue
 
             first_param_in_layer = info['params'][0]
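
The widened check makes the helper honor both spellings of the flag: standalone optimizers store `kourkoutas_beta` in their groups, while optimizers that managed it internally used the underscore-prefixed `_kourkoutas_beta`. The skip condition is equivalent to this predicate (toy dicts, not package code):

```python
# Process a layer when *either* spelling of the flag is set.
def uses_kourkoutas(group: dict) -> bool:
    return group.get('kourkoutas_beta', False) or group.get('_kourkoutas_beta', False)

print(uses_kourkoutas({'kourkoutas_beta': True}))   # True
print(uses_kourkoutas({'_kourkoutas_beta': True}))  # True
print(uses_kourkoutas({}))                          # False
```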
adv_optm-1.2.dev7.dist-info/METADATA → adv_optm-1.2.dev9.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.2.dev7
+Version: 1.2.dev9
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
adv_optm-1.2.dev7.dist-info/RECORD → adv_optm-1.2.dev9.dist-info/RECORD
@@ -1,24 +1,24 @@
-adv_optm/__init__.py,sha256=93-4akpaONvZ7BCkwDSYs3i28lI-aFV1RVwsEq1UZhU,379
-adv_optm/optim/AdaMuon_adv.py,sha256=hTGSH8wzmQ-NYIcqV6EAEbqCxxfEwmmMWaIadX1qiuQ,21009
+adv_optm/__init__.py,sha256=TzvKgGTLkK0_XANeZzhURcSO9xmtUi-H9_C7tV3rXn4,379
+adv_optm/optim/AdaMuon_adv.py,sha256=yr1oJV339Zv7D8n148O1FJJAgdOsH8NZDZTKlcDOyu0,21181
 adv_optm/optim/AdamW_adv.py,sha256=7IvdD1rqYeHZwQCZU9X0H7x87MCKcHQ5M68GLuMCkvE,17702
 adv_optm/optim/Adopt_adv.py,sha256=C2FsEZGvCk9q4YNKAj0qIxdZ5AfPlda-1lIpSX0a1nE,21256
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=JBLLfU83lRwezowI6A4JQAO1-NBLvSDOB8Dsad5zuHU,22775
+adv_optm/optim/Muon_adv.py,sha256=HaF06fPKcKpVZY29_vqjWHAfivjvGntBuRyDDKj3Ozw,22784
 adv_optm/optim/Prodigy_adv.py,sha256=bmwuO8GrJHH4NaEaqE-ffcR9wHhQ57457xoN-P6hyks,25909
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=sY-vThMVgADRh0ar9WHkrM2n8UcgQLQC1YV1Wx8uFz4,12983
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=WPAjxaH9pGVtLK_QJcwjkJOnN02Hfyu0F2T90hbhtqo,9662
+adv_optm/util/Kourkoutas.py,sha256=lObJGXmz3MqGSuu3DKqotSpZ0fuQFPE80R3zO_j3Z_Q,9707
 adv_optm/util/MuonAdam_helper.py,sha256=7rnNMujZVDaqo1g22QscMyPlZvIHQQSLHMED9_I8QWU,1250
 adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
-adv_optm-1.2.dev7.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.2.dev7.dist-info/METADATA,sha256=mAEVDwu_gh6S-fN6LBfEJoYdn_5LJLOw_nHRZcE7orw,14022
-adv_optm-1.2.dev7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.2.dev7.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.2.dev7.dist-info/RECORD,,
+adv_optm-1.2.dev9.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev9.dist-info/METADATA,sha256=GmAYWjZdfgvg9QbzyiV2PUNmzQFgJz8AjaY5F0x7Nv8,14022
+adv_optm-1.2.dev9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev9.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev9.dist-info/RECORD,,