adv-optm 2.5.7__tar.gz → 2.6.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/PKG-INFO +1 -1
  2. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/AdaMuon_adv.py +3 -3
  4. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/AdamW_adv.py +3 -1
  5. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Adopt_adv.py +3 -1
  6. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Lion_adv.py +3 -0
  7. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Muon_adv.py +3 -3
  8. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Prodigy_adv.py +6 -1
  9. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/SignSGD_adv.py +3 -0
  10. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/SinkSGD_adv.py +3 -1
  11. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/Kourkoutas.py +0 -28
  12. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/Muon_AuxAdam.py +1 -0
  13. adv_optm-2.6.dev1/adv_optm/util/msign.py +114 -0
  14. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/param_update.py +12 -0
  15. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
  16. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/SOURCES.txt +1 -0
  17. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/setup.py +1 -1
  18. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/LICENSE +0 -0
  19. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/README.md +0 -0
  20. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/__init__.py +0 -0
  21. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/Muon_util.py +0 -0
  22. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/OrthoGrad.py +0 -0
  23. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/__init__.py +0 -0
  24. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/centered_decay.py +0 -0
  25. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/factorization_util.py +0 -0
  26. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/lion_k.py +0 -0
  27. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/scaled_optm.py +0 -0
  28. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/signed_util.py +0 -0
  29. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/sinkhorn.py +0 -0
  30. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/state_util.py +0 -0
  31. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/update_util.py +0 -0
  32. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
  33. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/requires.txt +0 -0
  34. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/top_level.txt +0 -0
  35. {adv_optm-2.5.7 → adv_optm-2.6.dev1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.5.7
3
+ Version: 2.6.dev1
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.5.7"
23
+ __version__ = "2.6.dev1"
@@ -106,7 +106,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
106
106
  low-rank compression while keeping the first moment (momentum_buffer)
107
107
  dense. Ignored when `nnmf_factor=True` (full SMMF) or `normuon_variant=True`.
108
108
  Combines well with `state_precision` on the first moment. (default: False)
109
- n_layers (int): The depth of the network (L). Required for optimal epsilon scaling. (default: 1)
110
109
  spectral_normalization (bool): Enable explicit spectral normalization using power iteration. (default: False)
111
110
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
112
111
  adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
@@ -177,8 +176,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
177
176
  approx_mars: bool = False,
178
177
  mars_gamma: float = 0.025,
179
178
  # Spectral Normalization
180
- n_layers: int = 1,
181
179
  spectral_normalization: bool = False,
180
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
181
+ MSign_interval: int | None = None,
182
182
  # Centered WD
183
183
  centered_wd: float = 0.0,
184
184
  centered_wd_mode: str = 'float8',
@@ -249,7 +249,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
249
249
  # MARS-M
250
250
  "approx_mars": approx_mars, "mars_gamma": mars_gamma,
251
251
  # Spectral Normalization
252
- "n_layers": n_layers, "spectral_normalization": spectral_normalization,
252
+ "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
253
253
  # Centered WD
254
254
  "centered_wd": centered_wd,
255
255
  "centered_wd_mode": centered_wd_mode,
@@ -121,6 +121,8 @@ class AdamW_adv(torch.optim.Optimizer):
121
121
  layer_key_fn: Optional[Callable] = None,
122
122
  # Spectral Normed Optimizer
123
123
  spectral_normalization: bool = False,
124
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
125
+ MSign_interval: int | None = None,
124
126
  # Centered WD
125
127
  centered_wd: float = 0.0,
126
128
  centered_wd_mode: str = 'float8',
@@ -163,7 +165,7 @@ class AdamW_adv(torch.optim.Optimizer):
163
165
  "compiled_optimizer": compiled_optimizer,
164
166
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
165
167
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
166
- "spectral_normalization": spectral_normalization,
168
+ "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
167
169
  "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
168
170
  "state_precision": state_precision,
169
171
  "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
@@ -122,6 +122,8 @@ class Adopt_adv(torch.optim.Optimizer):
122
122
  layer_key_fn: Optional[Callable] = None,
123
123
  # Spectral Normed Optimizer
124
124
  spectral_normalization: bool = False,
125
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
126
+ MSign_interval: int | None = None,
125
127
  # Centered WD
126
128
  centered_wd: float = 0.0,
127
129
  centered_wd_mode: str = 'float8',
@@ -162,7 +164,7 @@ class Adopt_adv(torch.optim.Optimizer):
162
164
  "nesterov": nesterov, "nesterov_coef": nesterov_coef,
163
165
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
164
166
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
165
- "spectral_normalization": spectral_normalization,
167
+ "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
166
168
  "centered_wd": centered_wd,
167
169
  "centered_wd_mode": centered_wd_mode,
168
170
  "state_precision": state_precision,
@@ -78,6 +78,8 @@ class Lion_adv(torch.optim.Optimizer):
78
78
  centered_wd_mode: str = 'float8',
79
79
  # Spectral Normed Optimizer
80
80
  spectral_normalization: bool = False,
81
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
82
+ MSign_interval: int | None = None,
81
83
  # SMMF factorization
82
84
  nnmf_factor: bool = False,
83
85
  vector_reshape: bool = False,
@@ -102,6 +104,7 @@ class Lion_adv(torch.optim.Optimizer):
102
104
  auto_kappa_p=auto_kappa_p,
103
105
  stochastic_sign=stochastic_sign,
104
106
  spectral_normalization=spectral_normalization,
107
+ MSign_interval=MSign_interval,
105
108
  nnmf_factor=nnmf_factor,
106
109
  centered_wd= centered_wd,
107
110
  centered_wd_mode= centered_wd_mode,
@@ -81,7 +81,6 @@ class Muon_adv(torch.optim.Optimizer):
81
81
  'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
82
82
  'int8': Uses 8-bit block-wise quantization (block size 128).
83
83
  'int4': Uses 4-bit block-wise quantization (block size 32).
84
- n_layers (int): The depth of the network (L). Required for optimal epsilon scaling. (default: 1)
85
84
  spectral_normalization (bool): Enable explicit spectral normalization using power iteration. (default: False)
86
85
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
87
86
  adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
@@ -146,8 +145,9 @@ class Muon_adv(torch.optim.Optimizer):
146
145
  approx_mars: bool = False,
147
146
  mars_gamma: float = 0.025,
148
147
  # Spectral Normalization
149
- n_layers: int = 1,
150
148
  spectral_normalization: bool = False,
149
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
150
+ MSign_interval: int | None = None,
151
151
  # Centered WD
152
152
  centered_wd: float = 0.0,
153
153
  centered_wd_mode: str = 'float8',
@@ -220,7 +220,7 @@ class Muon_adv(torch.optim.Optimizer):
220
220
  # MARS-M
221
221
  "approx_mars": approx_mars, "mars_gamma": mars_gamma,
222
222
  # Spectral Normalization
223
- "n_layers": n_layers, "spectral_normalization": spectral_normalization,
223
+ "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
224
224
  # Centered WD
225
225
  "centered_wd": centered_wd,
226
226
  "centered_wd_mode": centered_wd_mode,
@@ -154,6 +154,10 @@ class Prodigy_adv(torch.optim.Optimizer):
154
154
  # Centered WD
155
155
  centered_wd: float = 0.0,
156
156
  centered_wd_mode: str = 'float8',
157
+ # Spectral Normalization
158
+ spectral_normalization: bool = False,
159
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
160
+ MSign_interval: int | None = None,
157
161
  ):
158
162
  if not (lr >= 0.0):
159
163
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -190,7 +194,8 @@ class Prodigy_adv(torch.optim.Optimizer):
190
194
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
191
195
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
192
196
  "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
193
- "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
197
+ "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
198
+ "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
194
199
  }
195
200
  self.stochastic_rounding = stochastic_rounding
196
201
  self.fsdp_in_use = fsdp_in_use
@@ -79,6 +79,8 @@ class SignSGD_adv(torch.optim.Optimizer):
79
79
  state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
80
80
  # Spectral Normed Optimizer
81
81
  spectral_normalization: bool = False,
82
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
83
+ MSign_interval: int | None = None,
82
84
  # SMMF factorization
83
85
  nnmf_factor: bool = False,
84
86
  vector_reshape: bool = False,
@@ -117,6 +119,7 @@ class SignSGD_adv(torch.optim.Optimizer):
117
119
  normed_momentum=normed_momentum,
118
120
  snr_cond=snr_cond,
119
121
  spectral_normalization=spectral_normalization,
122
+ MSign_interval=MSign_interval,
120
123
  centered_wd= centered_wd,
121
124
  centered_wd_mode= centered_wd_mode,
122
125
  state_precision=state_precision,
@@ -72,6 +72,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
72
72
  orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
73
73
  # Spectral Normed Optimizer
74
74
  spectral_normalization: bool = False,
75
+ # Orthogonalize the weights (Matrix Sign - MSign) every x steps
76
+ MSign_interval: int | None = None,
75
77
  # Centered WD
76
78
  centered_wd: float = 0.0,
77
79
  centered_wd_mode: str = 'float8',
@@ -108,7 +110,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
108
110
  "compiled_optimizer": compiled_optimizer,
109
111
  "sinkhorn_iterations": sinkhorn_iterations,
110
112
  "orthogonal_sinkhorn": orthogonal_sinkhorn,
111
- "spectral_normalization": spectral_normalization,
113
+ "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
112
114
  "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
113
115
  "state_precision": state_precision,
114
116
  "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape
@@ -168,8 +168,6 @@ class KourkoutasHelper:
168
168
  # Update the persistent EMA tensor in-place.
169
169
  r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
170
170
 
171
- tiny_spike = scale_tiny_spike(group, info['params'], tiny_spike)
172
-
173
171
  # Calculate Beta2
174
172
  raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
175
173
  sun = raw / (1.0 + raw)
@@ -255,29 +253,3 @@ class KourkoutasHelper:
255
253
  # The default is the max value, which is correct for unmapped params or edge cases
256
254
  beta2_default = group.get('betas', group.get('adam_betas'))[1] if group.get('betas', group.get('adam_betas')) else 0.999
257
255
  return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
258
-
259
-
260
- def scale_tiny_spike(group: dict, layer_params: list, tiny_spike: float) -> float:
261
- """
262
- Derives scale-invariant tiny_spike from the EMA tensor's effective numel.
263
- """
264
- if not group.get('spectral_normalization', False):
265
- return tiny_spike
266
-
267
- p0 = layer_params[0]
268
- if getattr(p0, '_is_lora_A', False) or p0.ndim < 2:
269
- # No depth scaling for:
270
- # - lora_A: non-zero init, different gradient dynamics than B
271
- # - 1D params (biases, norms, DoRA scales): additive, don't compound through depth.
272
- L = 1
273
- else:
274
- L = group['n_layers']
275
-
276
- if getattr(p0, '_is_lora_A', False) or getattr(p0, '_is_oft', False):
277
- ema_numel = p0.shape[1] # (1, in_features)
278
- elif getattr(p0, '_is_lora_B', False):
279
- ema_numel = p0.shape[0] # (out_features, 1)
280
- else:
281
- ema_numel = sum(p.numel() for p in layer_params) # scalar EMA
282
-
283
- return 1.0 / (L * math.sqrt(ema_numel))
@@ -66,6 +66,7 @@ def _init_auxadam_state(self, p, group):
66
66
  _init_anchor(p, state, group)
67
67
  _init_fisher_wd_scaler(group, state, p)
68
68
 
69
+ group["MSign_interval"] = None
69
70
 
70
71
  @torch.no_grad()
71
72
  def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor=None):
@@ -0,0 +1,114 @@
1
+ import torch
2
+
3
+ import math
4
+
5
@torch.no_grad()
def msign_ortho_precond_(p):
    """
    Replace a parameter tensor, in-place, by its approximate matrix sign,
    rescaled back to the tensor's original Frobenius norm.

    Originally proposed in the paper:
    "MSign: An Optimizer Preventing Training Instability in Large Language Models
    via Stable Rank Restoration" (arxiv:2602.01734).
    Modified to use CANS (Newton-Schulz iterations) instead of an SVD.

    Args:
        p (torch.Tensor): The parameter tensor to be preconditioned. It is
            flattened to `(p.shape[0], -1)` for the matrix operation, so
            higher-rank weights (e.g. Conv2D) are treated as
            `(out_channels, in_channels * ...)`. Vectors and biases should
            generally be excluded by the caller. Modified in-place.

    Returns:
        None. The result is written into `p`.

    Note:
        - The result is rescaled so that its Frobenius norm matches the input's
          (both norms are clamped to at least 1e-12 to avoid division by zero).
    """
    # Frobenius norm of the weights before projection (2-norm of the flattened tensor).
    orig_norm = torch.linalg.vector_norm(p, ord=2).clamp_min_(1e-12)

    # Flatten to 2D. `reshape` rather than `view`: a parameter whose strides do
    # not allow merging the trailing dims (e.g. channels_last conv weights)
    # makes `view` raise, while `reshape` falls back to a copy. The result is
    # written back through `copy_` below, so a copy here is harmless.
    p_2d = p.reshape(p.shape[0], -1)

    # Approximate the matrix sign (UV^T of the SVD) with CANS Newton-Schulz.
    p_sign = _cans_newton_schulz_iteration(
        p_2d,
        steps=15,
        eps=1e-7,
        cns_a_bound=None,  # auto
    )

    # Frobenius norm of the sign-projected matrix.
    sign_norm = torch.linalg.vector_norm(p_sign, ord=2).clamp_min_(1e-12)

    # Restore the original scale and write back into the parameter's storage.
    # `copy_` handles any stride mismatch between source and destination.
    p.copy_(p_sign.mul_(orig_norm / sign_norm).reshape(p.shape))
39
+
40
+ @torch.no_grad()
41
+ def _cans_newton_schulz_iteration(
42
+ G: torch.Tensor,
43
+ steps: int = 15,
44
+ eps: float = 1e-7,
45
+ cns_a_bound: float | None = None,
46
+ ) -> torch.Tensor:
47
+ """
48
+ Computes the matrix sign function using Chebyshev-Optimized Newton-Schulz (CANS) iterations.
49
+
50
+ This function iteratively approximates the orthogonalized version of a matrix (i.e., UV^T from SVD)
51
+ using purely matrix multiplications. This is highly optimized for Tensor Cores on modern GPUs
52
+ and is significantly faster than computing an exact SVD.
53
+
54
+ Args:
55
+ G (torch.Tensor): The input matrix/tensor to be orthogonalized. Must be at least 2D.
56
+ steps (int, optional): Number of Newton-Schulz iterations. Deep learning models
57
+ usually only require 5 or 6 steps for sufficient orthogonal preconditioner
58
+ precision due to quadratic convergence. Defaults to 6.
59
+ eps (float, optional): A small epsilon value used to prevent division by zero
60
+ during the initial spectral normalization. Defaults to 1e-7.
61
+ cns_a_bound (float | None, optional): The lower bound for singular values after
62
+ normalization. If None, it is automatically calculated based on the Marchenko-Pastur
63
+ theoretical minimum and baseline L2 bounds. Defaults to None.
64
+
65
+ Returns:
66
+ torch.Tensor: The approximate matrix sign of G (orthogonalized matrix), preserving
67
+ the original shape and device of G.
68
+ """
69
+ X = G
70
+
71
+ # Transpose if needed
72
+ transposed = X.size(-2) > X.size(-1)
73
+ if transposed:
74
+ X = X.mT
75
+
76
+ # Normalize spectral norm to at most 1
77
+ X.div_(X.norm(dim=(-2, -1), keepdim=True).clamp_min_(eps))
78
+
79
+ if cns_a_bound is None:
80
+ M = G.shape[-2]
81
+ N = G.shape[-1]
82
+ # baseline L2 norm bound (for square matrices)
83
+ baseline_bound = 1.0 / math.sqrt(M * N)
84
+ # Marchenko-Pastur theoretical minimum (for rectangular matrices)
85
+ mp_bound = abs(1.0 / math.sqrt(N) - 1.0 / math.sqrt(M))
86
+ # The optimal bound is safely the maximum of the two
87
+ cns_a_bound = max(baseline_bound, mp_bound)
88
+
89
+ lower_bound = cns_a_bound
90
+ upper_bound = 1.0
91
+ for _ in range(steps):
92
+ lb, ub = lower_bound, upper_bound
93
+ lb_ub = lb * ub
94
+ # Calculate Mean Square Error term
95
+ e_sq = (lb**2 + lb_ub + ub**2) / 3.0
96
+ # Calculate components for alpha and bounds update
97
+ K = 2.0 * e_sq**1.5
98
+ L = lb_ub * (lb + ub)
99
+ denom = K + L
100
+ alpha = 6.0 / denom
101
+ c1 = alpha * e_sq
102
+ c3 = -alpha / 3.0
103
+ # Apply the 3rd-order Newton-Schulz update
104
+ A = X @ X.mT
105
+ X = c1 * X + c3 * (A @ X)
106
+ # Update the singular value bounds for the next iteration based on the error
107
+ eps_val = (K - L) / denom
108
+ lower_bound, upper_bound = 1.0 - eps_val, 1.0 + eps_val
109
+
110
+ # Transpose back if necessary
111
+ if transposed:
112
+ X = X.mT
113
+
114
+ return X
@@ -8,6 +8,7 @@ from typing import Dict, Any
8
8
 
9
9
  from .scaled_optm import adjust_wds
10
10
  from .centered_decay import dequantize_anchor
11
+ from .msign import msign_ortho_precond_
11
12
 
12
13
  _generators: Dict[torch.device, torch.Generator] = {}
13
14
 
@@ -120,6 +121,11 @@ def apply_parameter_update(
120
121
 
121
122
  state = self.state[p]
122
123
 
124
+ ortho_interval = group.get('MSign_interval', None)
125
+ if ortho_interval is not None:
126
+ is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
127
+ is_ortho_step = (state['step'] % ortho_interval == 0) and not is_vector
128
+
123
129
  # Compute full update in float32 if using bfloat16 with stochastic rounding
124
130
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
125
131
  p_fp32 = p.float()
@@ -135,6 +141,9 @@ def apply_parameter_update(
135
141
  # Apply main update
136
142
  p_fp32.add_(-update_fp32)
137
143
 
144
+ if is_ortho_step:
145
+ msign_ortho_precond_(p_fp32)
146
+
138
147
  # Single stochastic rounding at the end
139
148
  if random_int_tensor is not None:
140
149
  # Compiled path: use the pre-computed random tensor
@@ -153,6 +162,9 @@ def apply_parameter_update(
153
162
  # Apply main update
154
163
  p.add_(-update)
155
164
 
165
+ if is_ortho_step:
166
+ msign_ortho_precond_(p)
167
+
156
168
  del update
157
169
 
158
170
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.5.7
3
+ Version: 2.6.dev1
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -24,6 +24,7 @@ adv_optm/util/__init__.py
24
24
  adv_optm/util/centered_decay.py
25
25
  adv_optm/util/factorization_util.py
26
26
  adv_optm/util/lion_k.py
27
+ adv_optm/util/msign.py
27
28
  adv_optm/util/param_update.py
28
29
  adv_optm/util/scaled_optm.py
29
30
  adv_optm/util/signed_util.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.5.7",
8
+ version="2.6.dev1",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes