adv-optm 2.3.dev1.tar.gz → 2.3.dev3.tar.gz

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (33)
  1. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/PKG-INFO +1 -1
  2. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/__init__.py +1 -1
  3. adv_optm-2.3.dev3/adv_optm/stiefel_optm/Stiefel_LoRA.py +231 -0
  4. adv_optm-2.3.dev3/adv_optm/stiefel_optm/__init__.py +5 -0
  5. adv_optm-2.3.dev3/adv_optm/stiefel_optm/stiefel_util.py +149 -0
  6. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
  7. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/SOURCES.txt +3 -0
  8. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/setup.py +1 -1
  9. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/LICENSE +0 -0
  10. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/README.md +0 -0
  11. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/AdaMuon_adv.py +0 -0
  12. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/AdamW_adv.py +0 -0
  13. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Adopt_adv.py +0 -0
  14. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  15. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Lion_adv.py +0 -0
  16. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Muon_adv.py +0 -0
  17. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Prodigy_adv.py +0 -0
  18. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/SignSGD_adv.py +0 -0
  19. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  20. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/optim/__init__.py +0 -0
  21. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/Kourkoutas.py +0 -0
  22. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/Muon_AuxAdam.py +0 -0
  23. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/Muon_util.py +0 -0
  24. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/OrthoGrad.py +0 -0
  25. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/__init__.py +0 -0
  26. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/factorization_util.py +0 -0
  27. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/lion_k.py +0 -0
  28. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/param_update.py +0 -0
  29. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm/util/update_util.py +0 -0
  30. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
  31. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/requires.txt +0 -0
  32. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/adv_optm.egg-info/top_level.txt +0 -0
  33. {adv_optm-2.3.dev1 → adv_optm-2.3.dev3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.3.dev1
+ Version: 2.3.dev3
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -28,4 +28,4 @@ __all__ = [
      "Stiefel_LoRA",
  ]
 
- __version__ = "2.3.dev1"
+ __version__ = "2.3.dev3"
@@ -0,0 +1,232 @@
+ import torch
+
+ from typing import Callable, Optional
+
+ from ..util import param_update
+ from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
+ from . import stiefel_util
+
+
+ class Stiefel_LoRA(torch.optim.Optimizer):
+     """
+     Implements an advanced Stiefel_LoRA algorithm.
+
+     Based on the paper:
+     ""
+     Under the hood it is a modified SignSGD with momentum (SignUM).
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups.
+         lr (float, optional): learning rate (default: 1e-4).
+         momentum (float, optional): coefficient for computing the running
+             average of the gradients (default: 0.9).
+         weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+         cautious_wd (bool): enables Cautious Weight Decay. If True, weight decay is
+             applied only to parameter coordinates where the sign of the parameter
+             and the sign of the optimizer update align (default: False).
+         vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+             matrices to apply low-rank compression (default: False).
+         stochastic_rounding (bool, optional): whether to use stochastic
+             rounding for BF16 parameter updates (default: True).
+         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update
+             rule. This changes the EMA to an accumulator and the update numerator
+             to `alpha_grad * grad + mt`, which can be more responsive, especially
+             for small batch sizes (default: False).
+         alpha_grad (float): mixing coefficient for the Simplified AdEMAMix update
+             rule (only used when `Simplified_AdEMAMix` is `True`). Controls the
+             weight of the current gradient. For small batch sizes, use high values
+             (e.g., 10-100) to be more responsive; for large batch sizes, use low
+             values (e.g., 0-1) for stability (default: 1.0).
+         nnmf_factor (bool): whether to use NNMF factorization for the momentum
+             state or keep it uncompressed (default: False).
+         compiled_optimizer (bool): whether to compile the per-parameter step
+             function with `torch.compile` (default: False).
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-4,
+         momentum: float = 0.9,
+         # Decoupled/cautious weight decay
+         weight_decay: float = 0.0,
+         cautious_wd: bool = False,
+         # Stochastic rounding for BF16
+         stochastic_rounding: bool = True,
+         # Simplified AdEMAMix
+         alpha_grad: float = 1.0,
+         Simplified_AdEMAMix: bool = False,
+         # SMMF factorization
+         nnmf_factor: bool = False,
+         vector_reshape: bool = False,
+         # torch.compile
+         compiled_optimizer: bool = False,
+     ):
+         if not lr > 0.0:
+             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+         if not 0.0 <= momentum <= 1.0:
+             raise ValueError(f"momentum should be in [0.0, 1.0], but got {momentum}")
+         if not weight_decay >= 0.0:
+             raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+
+         defaults = dict(
+             lr=lr,
+             momentum=momentum,
+             weight_decay=weight_decay,
+             cautious_wd=cautious_wd,
+             vector_reshape=vector_reshape,
+             alpha_grad=alpha_grad,
+             Simplified_AdEMAMix=Simplified_AdEMAMix,
+             nnmf_factor=nnmf_factor,
+         )
+         self.stochastic_rounding = stochastic_rounding
+         self._init_lr = lr
+
+         super().__init__(params, defaults)
+
+         if self.stochastic_rounding:
+             # For deterministic stochastic rounding, seed the generator
+             # for each device used by the BF16 parameters.
+             devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+             for device in devices:
+                 param_update.set_seed(device)
+
+         # Initialize the compiled step function
+         self._compiled_step_parameter = None
+         if compiled_optimizer:
+             self.compile(fullgraph=True)
+
+     @property
+     def supports_fused_back_pass(self) -> bool:
+         return True
+
+     @property
+     def supports_memory_efficient_fp16(self) -> bool:
+         return True
+
+     @property
+     def supports_flat_params(self) -> bool:
+         return False
+
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+         """Performs a single optimization step on a single parameter."""
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         state = self.state[p]
+
+         # State initialization
+         if group["momentum"] > 0 and len(state) == 0:
+             state['factored'] = (
+                 group['nnmf_factor'] and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+
+             dtype = torch.float32 if state['factored'] else p.dtype
+
+             if state['factored']:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+                 state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                 packed_d2 = (d2 + 7) // 8
+                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+             else:
+                 state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+         lr = group["lr"]
+
+         random_int_tensor = None
+
+         # Dispatch to the compiled step function when one was built in __init__
+         if self._compiled_step_parameter is not None:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 # Pre-generate the random tensor for stochastic rounding if needed.
+                 random_int_tensor = param_update._get_random_int_for_sr(p)
+             lr = torch.as_tensor(lr, dtype=torch.float64)
+             step_param_fn = self._compiled_step_parameter
+         else:
+             step_param_fn = self._step_parameter
+
+         step_param_fn(p, grad, state, group, lr, random_int_tensor)
+
+     def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
+         if grad.dtype != torch.float32 and state.get('factored'):
+             grad = grad.float()
+
+         is_stiefel, is_stiefel_euclidean, is_scale = stiefel_util.set_flags_AB(p)
+
+         momentum = group["momentum"]
+         Simplified_AdEMAMix = group["Simplified_AdEMAMix"]
+         alpha_grad = group["alpha_grad"]
+
+         if state.get('factored'):
+             # Factored path
+             d1, d2 = state['effective_shape']
+             grad_reshaped = grad.view(d1, d2)
+
+             if momentum > 0:
+                 # Reconstruct momentum m_{t-1}
+                 exp_avg = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
+                 exp_avg.mul_(momentum).add_(grad_reshaped)
+
+                 if Simplified_AdEMAMix:
+                     raw_update = exp_avg + (grad_reshaped * alpha_grad)
+                 else:
+                     raw_update = exp_avg.clone()
+
+                 # Compress the new momentum m_t and store the factors
+                 state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(exp_avg, signed=True)
+             else:
+                 raw_update = grad_reshaped.clone()
+
+             raw_update = raw_update.view(p.shape)
+
+         else:
+             # Fall back to standard SignSGD logic
+             if momentum > 0:
+                 exp_avg = state["exp_avg"]
+                 exp_avg.mul_(momentum).add_(grad)
+
+                 if Simplified_AdEMAMix:
+                     raw_update = exp_avg + (grad * alpha_grad)
+                 else:
+                     raw_update = exp_avg.clone()
+             else:
+                 raw_update = grad.clone()
+
+         if is_stiefel:
+             update = raw_update
+         else:
+             # The RMS of a sign() tensor is 1; scale by 0.2 to match the
+             # approximate RMS of AdamW updates.
+             update = raw_update.sign_().mul_(lr * 0.2)
+
+         stiefel_util.apply_stiefel_update(
+             self, p, group, update, lr,
+             random_int_tensor=random_int_tensor,
+             is_B=is_stiefel,
+             is_A=is_stiefel_euclidean,
+             is_scale=is_scale,
+         )
+
+     def compile(self, *args, **kwargs):
+         self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
+
+     @torch.no_grad()
+     def step(self, closure: Optional[Callable] = None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group["params"]):
+                 if p.grad is not None:
+                     self.step_parameter(p, group, i)
+
+         return loss
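For orientation, here is a minimal usage sketch of the optimizer added above. It is not part of the diff: the layer sizes, the orthogonal initialization of B, and the toy loss are illustrative assumptions. The `_is_lora_A`/`_is_lora_B` tags are the attributes that `set_flags_AB` in stiefel_util.py (below) checks before falling back to its shape heuristic.

    # Hypothetical usage sketch (assumed sizes and loss; not part of the package).
    import torch

    from adv_optm.stiefel_optm import Stiefel_LoRA

    # Toy rank-16 LoRA pair for a 512 -> 512 layer.
    A = torch.nn.Parameter(torch.randn(16, 512) * 0.01)  # (rank, in_features)
    B = torch.nn.Parameter(torch.empty(512, 16))         # (out_features, rank)
    torch.nn.init.orthogonal_(B)  # start B on the Stiefel manifold (B.T @ B = I)

    # Tag the parameters so set_flags_AB() does not need its shape heuristic.
    A._is_lora_A = True
    B._is_lora_B = True

    opt = Stiefel_LoRA([A, B], lr=1e-4, momentum=0.9, weight_decay=1e-2)

    x = torch.randn(8, 512)
    loss = (x @ A.t() @ B.t()).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

Because B is tagged as the Stiefel factor, its update is projected onto the tangent space and retracted with QR, while A receives the sign-based update with weight decay normalized by the rank.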
@@ -0,0 +1,5 @@
+ from .Stiefel_LoRA import Stiefel_LoRA
+
+ __all__ = [
+     "Stiefel_LoRA",
+ ]
@@ -0,0 +1,151 @@
+ import torch
+
+ from torch import Tensor
+
+ from typing import Dict, Any
+
+ def trangent_proj(p, update, lr):
+     """
+     [Stiefel-LoRA] Step 1: tangent-space projection.
+     Formula: update = update - p @ sym(p.T @ update)
+     """
+     pt_u = torch.matmul(p.t(), update)
+     sym_pt_u = 0.5 * (pt_u + pt_u.t())
+     # Project the update onto the tangent space at p
+     update.sub_(torch.matmul(p, sym_pt_u))
+     del pt_u, sym_pt_u
+     rms_rescaling(update, lr)
+     return update
+
+ def qr_retraction(p):
+     """[Stiefel-LoRA] Step 2: manifold retraction via QR decomposition."""
+     Q, R = torch.linalg.qr(p)
+     # Fix the sign ambiguity of QR so the retraction is deterministic
+     d = R.diagonal().sign_()
+     Q *= d
+     return p.copy_(Q)
+
+ def rms_rescaling(update, lr):
+     """Rescales `update` in place so that its RMS (root mean square) is 0.2 * lr."""
+     # RMS = sqrt(mean(update**2))
+     # Frobenius norm = sqrt(sum(update**2))
+     # Relationship: norm = RMS * sqrt(numel)
+     numel = update.numel()
+     norm = torch.linalg.vector_norm(update).clamp_min_(1e-12)
+     target_norm = lr * (numel ** 0.5) * 0.2
+     return update.mul_(target_norm / norm)
+
+ def set_flags_AB(p):
+     """
+     Identify whether a parameter is LoRA B, LoRA A, or a scale parameter.
+     Returns (is_B, is_A, is_scale).
+     """
+     if getattr(p, '_is_dora_scale', False):
+         return False, False, True
+     if getattr(p, '_is_lora_B', False):
+         return True, False, False
+     if getattr(p, '_is_lora_A', False):
+         return False, True, False
+
+     # Fallback heuristic (handles 4D Conv2d layers properly)
+     dim0 = p.shape[0]
+     dim1 = p.shape[1] if p.ndim > 1 else 1
+
+     is_scale = p.ndim == 1 or (p.ndim == 2 and (dim0 == 1 or dim1 == 1))
+     if is_scale:
+         return False, False, True
+
+     B = dim0 > dim1
+     A = dim0 < dim1
+     return B, A, False
+
+ def apply_stiefel_update(
+     self,
+     p: Tensor,
+     group: Dict[str, Any],
+     update: Tensor,
+     lr: float | Tensor,
+     wd: float | None = None,
+     random_int_tensor: Tensor | None = None,
+     is_B: bool | None = None,
+     is_A: bool | None = None,
+     is_scale: bool | None = False,
+ ) -> None:
+     from ..util.param_update import _copy_stochastic_core_, copy_stochastic_
+     wd = group["weight_decay"] if wd is None else wd
+     cautious = group.get('cautious_wd', False)
+
+     if is_B or p.ndim == 1 or is_scale:
+         # Disable weight decay for the orthogonal matrix B and the DoRA norm
+         wd = 0
+
+     if is_A:
+         # For matrix A, normalize weight decay by rank to make it rank-invariant
+         wd = wd / p.shape[0]
+
+     scaled_wd = wd * (lr / self._init_lr)
+
+     # Compute the full update in float32 when using bfloat16 with stochastic rounding
+     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+         p_fp32 = p.float()
+         update_fp32 = update.float()
+
+         if is_B:
+             update_fp32 = trangent_proj(p_fp32, update_fp32, lr)
+             p_fp32.add_(-update_fp32)
+             p_fp32 = qr_retraction(p_fp32)
+             if random_int_tensor is not None:
+                 _copy_stochastic_core_(p, p_fp32, random_int_tensor)
+                 del random_int_tensor
+             else:
+                 copy_stochastic_(p, p_fp32)
+             del p_fp32, update_fp32
+             return
+
+         # Apply weight decay if needed
+         if wd != 0:
+             if cautious:
+                 # Cautious weight decay
+                 mask = (update_fp32 * p_fp32 >= 0).float()
+                 p_fp32.addcmul_(p_fp32, mask, value=-scaled_wd)
+                 del mask
+             else:
+                 # Standard decoupled weight decay
+                 p_fp32.add_(p_fp32, alpha=-scaled_wd)
+
+         # Apply the main update
+         p_fp32.add_(-update_fp32)
+
+         # Single stochastic rounding at the end
+         if random_int_tensor is not None:
+             # Compiled path: use the pre-computed random tensor
+             _copy_stochastic_core_(p, p_fp32, random_int_tensor)
+             del random_int_tensor
+         else:
+             # Uncompiled path: generate the randoms inside
+             copy_stochastic_(p, p_fp32)
+         del p_fp32, update_fp32
+
+     else:
+         # Standard path for non-bfloat16 or without stochastic rounding
+
+         if is_B:
+             update = trangent_proj(p, update, lr)
+             p.add_(-update)
+             p = qr_retraction(p)
+             return
+
+         if wd != 0:
+             if cautious:
+                 # Cautious weight decay
+                 mask = (update * p >= 0).to(p.dtype)
+                 p.addcmul_(p, mask, value=-scaled_wd)
+                 del mask
+             else:
+                 # Standard decoupled weight decay
+                 p.add_(p, alpha=-scaled_wd)
+
+         # Apply the main update
+         p.add_(-update)
+
+         del update
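To make the geometry concrete, here is a small sanity check (a sketch; the random 64x8 matrix stands in for a LoRA B factor) of the two invariants these helpers maintain: after trangent_proj, sym(p.T @ update) vanishes, i.e. the update lies in the tangent space of the Stiefel manifold at p; after qr_retraction, p has orthonormal columns again.

    # Hypothetical sanity check (assumed shapes; not part of the package).
    import torch

    from adv_optm.stiefel_optm.stiefel_util import trangent_proj, qr_retraction

    torch.manual_seed(0)
    p = torch.linalg.qr(torch.randn(64, 8)).Q  # a point on the Stiefel manifold
    u = torch.randn(64, 8)

    u = trangent_proj(p, u, lr=1e-3)  # project onto the tangent space, then RMS-rescale
    sym = 0.5 * (p.t() @ u + u.t() @ p)
    print(sym.abs().max())  # ~0 (up to float error): u is tangent at p

    p.sub_(u)         # take the step p <- p - u, as apply_stiefel_update does
    qr_retraction(p)  # retract back onto the manifold, in place
    print((p.t() @ p - torch.eye(8)).abs().max())  # ~0: columns orthonormal again

The sign fix on R's diagonal in qr_retraction resolves the sign ambiguity of the QR factorization, so the retraction is deterministic rather than flipping column signs between steps.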
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.3.dev1
+ Version: 2.3.dev3
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -17,6 +17,9 @@ adv_optm/optim/Prodigy_adv.py
  adv_optm/optim/SignSGD_adv.py
  adv_optm/optim/Simplified_AdEMAMix.py
  adv_optm/optim/__init__.py
+ adv_optm/stiefel_optm/Stiefel_LoRA.py
+ adv_optm/stiefel_optm/__init__.py
+ adv_optm/stiefel_optm/stiefel_util.py
  adv_optm/util/Kourkoutas.py
  adv_optm/util/Muon_AuxAdam.py
  adv_optm/util/Muon_util.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 
  setup(
      name="adv_optm",
-     version="2.3.dev1",
+     version="2.3.dev3",
      author="Koratahiu",
      author_email="hiuhonor@gmail.com",
      license='Apache 2.0',