adv-optm 2.5.9.tar.gz → 2.6.1.dev2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/PKG-INFO +1 -1
  2. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/AdaMuon_adv.py +2 -1
  4. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/AdamW_adv.py +2 -1
  5. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Adopt_adv.py +2 -1
  6. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Lion_adv.py +2 -0
  7. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Muon_adv.py +2 -1
  8. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Prodigy_adv.py +2 -1
  9. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/SignSGD_adv.py +2 -0
  10. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/SinkSGD_adv.py +2 -1
  11. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/OrthoGrad.py +1 -1
  12. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/param_update.py +33 -27
  13. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/scaled_optm.py +24 -37
  14. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
  15. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/setup.py +1 -1
  16. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/LICENSE +0 -0
  17. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/README.md +0 -0
  18. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/__init__.py +0 -0
  19. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/Kourkoutas.py +0 -0
  20. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/Muon_AuxAdam.py +0 -0
  21. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/Muon_util.py +0 -0
  22. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/__init__.py +0 -0
  23. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/centered_decay.py +0 -0
  24. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/factorization_util.py +0 -0
  25. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/lion_k.py +0 -0
  26. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/signed_util.py +0 -0
  27. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/sinkhorn.py +0 -0
  28. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/state_util.py +0 -0
  29. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/update_util.py +0 -0
  30. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
  31. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
  32. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/requires.txt +0 -0
  33. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/top_level.txt +0 -0
  34. {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.5.9
3
+ Version: 2.6.1.dev2
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.5.9"
23
+ __version__ = "2.6.1.dev2"
@@ -137,6 +137,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
137
137
  # Decoupled/cautious weight decay
138
138
  weight_decay: float = 0,
139
139
  cautious_wd: bool = False,
140
+ scaled_wd: bool = False,
140
141
  # Nesterov momentum
141
142
  nesterov: bool = True,
142
143
  nesterov_coef: float | None = None,
@@ -227,7 +228,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
227
228
 
228
229
  defaults = {
229
230
  "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
230
- "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
231
+ "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps, "scaled_wd": scaled_wd,
231
232
  "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
232
233
  "vector_reshape": vector_reshape,
233
234
  "nesterov":nesterov, "nesterov_coef": nesterov_coef, "use_atan2":use_atan2,
@@ -98,6 +98,7 @@ class AdamW_adv(torch.optim.Optimizer):
98
98
  weight_decay: float = 0.0,
99
99
  fisher_wd: bool = False,
100
100
  cautious_wd: bool = False,
101
+ scaled_wd: bool = False,
101
102
  # Adam's Bias Correction
102
103
  use_bias_correction: bool = True,
103
104
  # Stochastic Rounding for BF16
@@ -156,7 +157,7 @@ class AdamW_adv(torch.optim.Optimizer):
156
157
 
157
158
  defaults = {
158
159
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
159
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
160
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
160
161
  "use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
161
162
  "normed_momentum": normed_momentum,
162
163
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
@@ -101,6 +101,7 @@ class Adopt_adv(torch.optim.Optimizer):
101
101
  weight_decay: float = 0.0,
102
102
  fisher_wd: bool = False,
103
103
  cautious_wd: bool = False,
104
+ scaled_wd: bool = False,
104
105
  # ADOPT clipping
105
106
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
106
107
  # Adam_atan2 (scale invariant)
@@ -157,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
157
158
  state_precision = "factored"
158
159
 
159
160
  defaults = {
160
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
161
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "scaled_wd": scaled_wd,
161
162
  "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
162
163
  "nesterov": nesterov, "nesterov_coef": nesterov_coef,
163
164
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -64,6 +64,7 @@ class Lion_adv(torch.optim.Optimizer):
64
64
  # Decoupled/cautious weight decay
65
65
  weight_decay: float = 0.0,
66
66
  cautious_wd: bool = False,
67
+ scaled_wd: bool = False,
67
68
  # Stochastic Rounding for BF16
68
69
  stochastic_rounding: bool = True,
69
70
  # OrthoGrad
@@ -96,6 +97,7 @@ class Lion_adv(torch.optim.Optimizer):
96
97
  betas=betas,
97
98
  weight_decay=weight_decay,
98
99
  cautious_wd=cautious_wd,
100
+ scaled_wd=scaled_wd,
99
101
  vector_reshape=vector_reshape,
100
102
  orthogonal_gradient=orthogonal_gradient,
101
103
  kappa_p=kappa_p,
@@ -111,6 +111,7 @@ class Muon_adv(torch.optim.Optimizer):
111
111
  # Decoupled/cautious weight decay
112
112
  weight_decay: float = 0.0,
113
113
  cautious_wd: bool = False,
114
+ scaled_wd: bool = False,
114
115
  # Nesterov momentum
115
116
  nesterov: bool = True,
116
117
  nesterov_coef: float | None = None,
@@ -201,7 +202,7 @@ class Muon_adv(torch.optim.Optimizer):
201
202
  defaults = {
202
203
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
203
204
  "nesterov": nesterov, "nesterov_coef": nesterov_coef, "ns_steps": ns_steps, "ns_eps": ns_eps,
204
- "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
205
+ "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor, "scaled_wd": scaled_wd,
205
206
  "vector_reshape": vector_reshape, "rms_rescaling": rms_rescaling,
206
207
  "orthogonal_gradient": orthogonal_gradient,
207
208
  'compiled_optimizer': compiled_optimizer,
@@ -115,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
115
115
  weight_decay: float = 0.0,
116
116
  fisher_wd: bool = False,
117
117
  cautious_wd: bool = False,
118
+ scaled_wd: bool = False,
118
119
  # Stochastic Rounding for BF16
119
120
  stochastic_rounding: bool = True,
120
121
  # Adam_atan2 (scale invariant)
@@ -181,7 +182,7 @@ class Prodigy_adv(torch.optim.Optimizer):
181
182
 
182
183
  defaults = {
183
184
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
184
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
185
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
185
186
  "use_atan2": use_atan2,
186
187
  "orthogonal_gradient": orthogonal_gradient,
187
188
  "compiled_optimizer": compiled_optimizer,
@@ -59,6 +59,7 @@ class SignSGD_adv(torch.optim.Optimizer):
59
59
  # weight decay features
60
60
  geometric_wd: bool = False,
61
61
  cautious_wd: bool = False,
62
+ scaled_wd: bool = False,
62
63
  # Stochastic Rounding for BF16
63
64
  stochastic_rounding: bool = True,
64
65
  # OrthoGrad
@@ -108,6 +109,7 @@ class SignSGD_adv(torch.optim.Optimizer):
108
109
  momentum=momentum,
109
110
  weight_decay=weight_decay,
110
111
  cautious_wd=cautious_wd,
112
+ scaled_wd=scaled_wd,
111
113
  geometric_wd=geometric_wd,
112
114
  vector_reshape=vector_reshape,
113
115
  orthogonal_gradient=orthogonal_gradient,
@@ -66,6 +66,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
66
66
  # weight decay features
67
67
  geometric_wd: bool = False,
68
68
  cautious_wd: bool = False,
69
+ scaled_wd: bool = False,
69
70
  # Stochastic Rounding for BF16
70
71
  stochastic_rounding: bool = True,
71
72
  # OrthoGrad
@@ -103,7 +104,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
103
104
  defaults = {
104
105
  "lr": lr, "momentum": momentum,
105
106
  "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "snr_cond": snr_cond,
106
- "geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
107
+ "geometric_wd": geometric_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
107
108
  "orthogonal_gradient": orthogonal_gradient,
108
109
  "compiled_optimizer": compiled_optimizer,
109
110
  "sinkhorn_iterations": sinkhorn_iterations,
@@ -43,7 +43,7 @@ def iterative_ortho_project(p: torch.Tensor, grad: torch.Tensor, iters: int = 3)
43
43
  # 1D Vector Case fallback to the standard OrthoGrad
44
44
  is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
45
45
  if is_vector:
46
- return _orthogonalize_gradient(p, grad)
46
+ return flattened_ortho_project(p, grad)
47
47
 
48
48
  original_shape = grad.shape
49
49
 
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
6
 
7
7
  from typing import Dict, Any
8
8
 
9
- from .scaled_optm import adjust_wds
9
+ from .scaled_optm import adjust_wds, scale_wd
10
10
  from .centered_decay import dequantize_anchor
11
11
 
12
12
  _generators: Dict[torch.device, torch.Generator] = {}
@@ -18,8 +18,8 @@ def _apply_weight_decay(
18
18
  p: Tensor,
19
19
  state: Dict[str, Any],
20
20
  group: Dict[str, Any],
21
- scaled_wd: float | Tensor | None,
22
- scaled_cwd: float | Tensor | None,
21
+ eff_wd: float | Tensor | None,
22
+ eff_cwd: float | Tensor | None,
23
23
  wd_target: Tensor | None = None,
24
24
  cwd_target: Tensor | None = None,
25
25
  ) -> None:
@@ -29,26 +29,26 @@ def _apply_weight_decay(
29
29
  cautious = group.get('cautious_wd', False)
30
30
 
31
31
  # Standard Weight Decay (pulls toward zero)
32
- if scaled_wd is not None:
32
+ if eff_wd is not None:
33
33
  if wd_target is None:
34
34
  wd_target = p_calc
35
35
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
36
36
  if cautious:
37
37
  mask = (update_calc * p_calc >= 0).to(p_calc.dtype)
38
- if isinstance(scaled_wd, Tensor):
39
- p_calc.addcmul_(wd_target, mask * scaled_wd, value=-1.0)
38
+ if isinstance(eff_wd, Tensor):
39
+ p_calc.addcmul_(wd_target, mask * eff_wd, value=-1.0)
40
40
  else:
41
- p_calc.addcmul_(wd_target, mask, value=-scaled_wd)
41
+ p_calc.addcmul_(wd_target, mask, value=-eff_wd)
42
42
  del mask
43
43
  else:
44
44
  # Standard decoupled weight decay
45
- if isinstance(scaled_wd, Tensor):
46
- p_calc.addcmul_(wd_target, scaled_wd, value=-1.0)
45
+ if isinstance(eff_wd, Tensor):
46
+ p_calc.addcmul_(wd_target, eff_wd, value=-1.0)
47
47
  else:
48
- p_calc.add_(wd_target, alpha=-scaled_wd)
48
+ p_calc.add_(wd_target, alpha=-eff_wd)
49
49
 
50
50
  # Centered Weight Decay (pulls toward anchor)
51
- if scaled_cwd is not None and 'anchor_data' in state:
51
+ if eff_cwd is not None and 'anchor_data' in state:
52
52
  if cwd_target is not None:
53
53
  decay_target = cwd_target
54
54
  else:
@@ -59,17 +59,17 @@ def _apply_weight_decay(
59
59
  if cautious:
60
60
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
61
61
  mask = (update_calc * decay_target >= 0).to(p_calc.dtype)
62
- if isinstance(scaled_cwd, Tensor):
63
- p_calc.addcmul_(decay_target, mask * scaled_cwd, value=-1.0)
62
+ if isinstance(eff_cwd, Tensor):
63
+ p_calc.addcmul_(decay_target, mask * eff_cwd, value=-1.0)
64
64
  else:
65
- p_calc.addcmul_(decay_target, mask, value=-scaled_cwd)
65
+ p_calc.addcmul_(decay_target, mask, value=-eff_cwd)
66
66
  del mask
67
67
  else:
68
68
  # Standard decoupled weight decay
69
- if isinstance(scaled_cwd, Tensor):
70
- p_calc.addcmul_(decay_target, scaled_cwd, value=-1.0)
69
+ if isinstance(eff_cwd, Tensor):
70
+ p_calc.addcmul_(decay_target, eff_cwd, value=-1.0)
71
71
  else:
72
- p_calc.add_(decay_target, alpha=-scaled_cwd)
72
+ p_calc.add_(decay_target, alpha=-eff_cwd)
73
73
 
74
74
  if cwd_target is None:
75
75
  del decay_target
@@ -105,18 +105,24 @@ def apply_parameter_update(
105
105
  wd = group["weight_decay"] if wd is None else wd
106
106
  cwd = group.get("centered_wd", 0.0)
107
107
  wd, cwd = adjust_wds(wd, cwd, p)
108
+ scaled_wd = group.get("scaled_wd", False)
109
+ decoupled = scaled_wd
108
110
 
109
111
  # Calculate global decay factor for decoupled vs standard
110
112
  decay_factor = (lr / self._init_lr) if decoupled else lr
111
113
 
112
- scaled_wd = (wd * decay_factor) if wd != 0 else None
113
- scaled_cwd = (cwd * decay_factor) if cwd != 0 else None
114
+ eff_wd = (wd * decay_factor) if wd != 0 else None
115
+ eff_cwd = (cwd * decay_factor) if cwd != 0 else None
114
116
 
115
117
  if wd_scaler is not None:
116
- if scaled_wd is not None:
117
- scaled_wd = scaled_wd * wd_scaler
118
- if scaled_cwd is not None:
119
- scaled_cwd = scaled_cwd * wd_scaler
118
+ if eff_wd is not None:
119
+ if scaled_wd:
120
+ eff_wd = scale_wd(eff_wd, p)
121
+ eff_wd = eff_wd * wd_scaler
122
+ if eff_cwd is not None:
123
+ if scaled_wd:
124
+ eff_cwd = scale_wd(eff_cwd, p)
125
+ eff_cwd = eff_cwd * wd_scaler
120
126
 
121
127
  state = self.state[p]
122
128
 
@@ -129,8 +135,8 @@ def apply_parameter_update(
129
135
  cwd_t = cwd_target.float() if cwd_target is not None else None
130
136
 
131
137
  # Apply weight decay if needed
132
- if scaled_wd is not None or scaled_cwd is not None:
133
- _apply_weight_decay(p_fp32, update_fp32, p, state, group, scaled_wd, scaled_cwd, wd_t, cwd_t)
138
+ if eff_wd is not None or eff_cwd is not None:
139
+ _apply_weight_decay(p_fp32, update_fp32, p, state, group, eff_wd, eff_cwd, wd_t, cwd_t)
134
140
 
135
141
  # Apply main update
136
142
  p_fp32.add_(-update_fp32)
@@ -147,8 +153,8 @@ def apply_parameter_update(
147
153
 
148
154
  else:
149
155
  # Standard path for non-bfloat16 or without stochastic rounding
150
- if scaled_wd is not None or scaled_cwd is not None:
151
- _apply_weight_decay(p, update, p, state, group, scaled_wd, scaled_cwd, wd_target, cwd_target)
156
+ if eff_wd is not None or eff_cwd is not None:
157
+ _apply_weight_decay(p, update, p, state, group, eff_wd, eff_cwd, wd_target, cwd_target)
152
158
 
153
159
  # Apply main update
154
160
  p.add_(-update)
@@ -7,14 +7,14 @@ import math
7
7
  _OFT_INDICES_CACHE = {}
8
8
  _OFT_IDENTITY_CACHE = {}
9
9
 
10
- def get_cached_structural_tensors(b: int, dtype: torch.dtype, device: torch.device):
10
+ def get_cached_structural_tensors(b: int, device: torch.device):
11
11
  """
12
- Retrieves or creates structural tensors (indices and Identity) for OFT exact geometry.
12
+ Retrieves or creates structural tensors (indices) for OFT exact geometry.
13
13
  Caches them globally to prevent redundant memory allocation across thousands of layers.
14
14
  """
15
- global _OFT_INDICES_CACHE, _OFT_IDENTITY_CACHE
15
+ global _OFT_INDICES_CACHE
16
16
 
17
- # Cache for Indices (Dtype independent, only depends on block size and device)
17
+ # Cache for Indices
18
18
  idx_key = (b, device)
19
19
  if idx_key not in _OFT_INDICES_CACHE:
20
20
  rows, cols = torch.triu_indices(b, b, 1, device=device)
@@ -22,15 +22,8 @@ def get_cached_structural_tensors(b: int, dtype: torch.dtype, device: torch.devi
22
22
  else:
23
23
  rows, cols = _OFT_INDICES_CACHE[idx_key]
24
24
 
25
- # Cache for Identity Matrix (Depends on block size, dtype, and device)
26
- id_key = (b, dtype, device)
27
- if id_key not in _OFT_IDENTITY_CACHE:
28
- I = torch.eye(b, dtype=dtype, device=device).unsqueeze(0)
29
- _OFT_IDENTITY_CACHE[id_key] = I
30
- else:
31
- I = _OFT_IDENTITY_CACHE[id_key]
32
25
 
33
- return rows, cols, I
26
+ return rows, cols
34
27
 
35
28
  def scale_update(
36
29
  p: torch.Tensor,
@@ -59,7 +52,7 @@ def scale_update(
59
52
  return max_abs_normalization(update, dim=None, lr=lr)
60
53
 
61
54
  # OFT Block Parameters: shape (k, C(b,2))
62
- # Direct spectral normalization on the skew-symmetric blocks, followed by Riemannian preconditioning.
55
+ # Direct spectral normalization on the skew-symmetric blocks.
63
56
  if is_oft:
64
57
  return apply_spectral_riemannian_oft(p, update, lr, state)
65
58
 
@@ -104,6 +97,20 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
104
97
  # Centered WD safely regularizes the delta without collapsing base feature variance.
105
98
  return wd, cwd
106
99
 
100
+ def scale_wd(wd: float, p: torch.Tensor) -> float:
101
+ """
102
+ Scale-invariant, dimension-scaled weight decay.
103
+ """
104
+ if getattr(p, '_is_oft', False):
105
+ n_el = p.shape[-1]
106
+ b = (1.0 + math.sqrt(1.0 + 8.0 * n_el)) / 2.0
107
+ wd = (2 * wd) / (b - 1)
108
+ return wd
109
+
110
+ if p.ndim >= 2:
111
+ width = p.numel() // p.shape[0]
112
+ return wd / width
113
+
107
114
 
108
115
  def is_spectral(p: torch.Tensor) -> bool:
109
116
  """Determines if a parameter should undergo spectral normalization updates."""
@@ -176,34 +183,26 @@ def apply_spectral_riemannian_oft(
176
183
  state: dict
177
184
  ) -> torch.Tensor:
178
185
  """
179
- Applies Spectral Normalization directly on the skew-symmetric gradient,
180
- then uses True Matrix Preconditioning: M @ G @ M where M = (I - Q^2).
181
- Neutralizes the derivative shrinkage of the Cayley transform.
186
+ Applies Spectral Normalization directly on the skew-symmetric gradient.
182
187
  """
183
188
  n_el = p.shape[-1]
184
189
  block_size = int((1 + math.sqrt(1 + 8 * n_el)) / 2)
185
190
  device, dtype = p.device, p.dtype
186
- rows, cols, I = get_cached_structural_tensors(block_size, dtype, device)
191
+ rows, cols = get_cached_structural_tensors(block_size, device)
187
192
 
188
193
  # Flatten any prepended batch dimensions for processing
189
194
  orig_shape = p.shape
190
195
 
191
196
  # Align the scale of p with the forward pass
192
197
  scale_factor = getattr(p, '_oft_scale_factor', 1.0)
193
- p_flat = p.view(-1, n_el) / scale_factor
194
198
 
195
199
  update_flat = update.view(-1, n_el)
196
- batch_size = p_flat.shape[0]
200
+ batch_size = update_flat.shape[0]
197
201
 
198
202
  # Initialize matrices
199
- Q = torch.zeros(batch_size, block_size, block_size, device=device, dtype=dtype)
200
203
  G = torch.zeros(batch_size, block_size, block_size, device=device, dtype=dtype)
201
204
  batch_idx = torch.arange(batch_size, device=device)[:, None]
202
205
 
203
- # Construct skew-symmetric parameter matrix Q
204
- Q = Q.index_put((batch_idx, rows, cols), p_flat)
205
- Q = Q - Q.transpose(-2, -1)
206
-
207
206
  # Construct skew-symmetric gradient matrix G
208
207
  G = G.index_put((batch_idx, rows, cols), update_flat)
209
208
  G = G - G.transpose(-2, -1)
@@ -235,21 +234,9 @@ def apply_spectral_riemannian_oft(
235
234
  target_scale = 0.5 * scale_factor
236
235
  spectral_eps = 1.0 / (2.0 * math.sqrt(block_size))
237
236
 
238
- # Rescale G
239
237
  scale = lr * (target_scale / max_sigma.clamp_min(spectral_eps))
240
- G = G * scale
241
-
242
- # Apply Riemannian Preconditioning
243
- # Compute True Matrix Preconditioner M = I - Q^2
244
- M = I - torch.bmm(Q, Q)
245
-
246
- # Apply exact preconditioning: G_prec = M @ G @ M
247
- G_prec = torch.bmm(torch.bmm(M, G), M)
248
-
249
- # Extract the preconditioned upper-triangular elements
250
- update_prec_flat = G_prec[batch_idx, rows, cols]
251
238
 
252
- return update_prec_flat.view(orig_shape)
239
+ return update_flat.mul_(scale).view(orig_shape)
253
240
 
254
241
 
255
242
  @torch.no_grad()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.5.9
3
+ Version: 2.6.1.dev2
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.5.9",
8
+ version="2.6.1.dev2",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes