adv-optm 1.2.dev19-py3-none-any.whl → 2.dev3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +11 -9
- adv_optm/optim/AdamW_adv.py +91 -61
- adv_optm/optim/Adopt_adv.py +113 -68
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +59 -43
- adv_optm/optim/Muon_adv.py +13 -9
- adv_optm/optim/Prodigy_adv.py +108 -86
- adv_optm/optim/Simplified_AdEMAMix.py +93 -52
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +10 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/METADATA +20 -20
- adv_optm-2.dev3.dist-info/RECORD +23 -0
- adv_optm-1.2.dev19.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/top_level.txt +0 -0
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED

@@ -66,10 +66,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
-            logging (default: 0).
+            logging (default: 0).
         layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
-            If `None`, parameters are bucketed by their
+            If `None`, parameters are bucketed by their shape.
             (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)

@@ -86,7 +86,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         beta1_warmup: int | None = None,
         min_beta1: float | None = 0.9,
         use_bias_correction: bool = True,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
         kourkoutas_beta: bool = False,

@@ -97,6 +97,8 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         k_logging: int = 0,
         layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")

@@ -108,7 +110,8 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not 0.0 <= alpha_grad:
             raise ValueError("Invalid alpha value: {}".format(alpha_grad))
-        if kourkoutas_beta and not (betas[1] > beta2_min):
+        if kourkoutas_beta and not (betas[1] > beta2_min):
+            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")

         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,

@@ -117,16 +120,33 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+            "compiled_optimizer": compiled_optimizer,
         }
         self.stochastic_rounding = stochastic_rounding
         self.factored = nnmf_factor
         self.kourkoutas_beta = kourkoutas_beta
         self.layer_key_fn = layer_key_fn
+        self.use_bias_correction = use_bias_correction
+        if use_bias_correction:
+            self.num_sum = betas[0] * 1.0
+            self.den_sum = betas[1] * (1.0 - betas[1])
+        else:
+            self.num_sum = 1.0
+            self.den_sum = 1.0
+
         super().__init__(params, defaults)

+        self.init_step()
+
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)

+        self.global_step = 0
+
+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
     @property
     def supports_fused_back_pass(self):
         return True
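Assuming the constructor shown in the hunks above, a minimal usage sketch of the new `compiled_optimizer` flag might look like the following. The model, hyperparameter values, and the top-level import path are placeholders, not taken from the package:

```python
import torch
import torch.nn as nn
from adv_optm import Simplified_AdEMAMix  # import path assumed from the package layout

model = nn.Linear(128, 64)

# `compiled_optimizer=True` makes the constructor raise torch._dynamo's cache
# size limit and wrap the per-parameter step with torch.compile(fullgraph=True),
# as added in 2.dev3.
opt = Simplified_AdEMAMix(
    model.parameters(),
    lr=1e-4,
    kourkoutas_beta=True,      # dynamic per-layer beta2
    orthogonal_gradient=True,  # OrthoGrad projection of the gradient
    nnmf_factor=False,         # uncompressed states
    compiled_optimizer=True,   # new in this release
)

loss = model(torch.randn(8, 128)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```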
@@ -139,29 +159,22 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False

-
-
-
-
+    def init_step(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                self.__init_state(p, group)

-
-
-            grad = grad.float()
-        if group["orthogonal_gradient"]:
-            grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]

-
-        if 'step' not in state:
-            state['step'] = 0
+        if len(state) == 0:

-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

-            state['factored'] = should_factor
-
             dtype = torch.float32 if self.factored else p.dtype
             device = p.device

@@ -170,50 +183,42 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']

                 # First moment (m)
-                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                 # Second moment (v)
-                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else:  # Fallback to standard optimizer for non-factored tensors
                 state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
-
-            if group['use_bias_correction']:
-                state['num_sum'] = 0.0
-                state['den_sum'] = 0.0
-            else:
-                state['num_sum'] = 1.0
-                state['den_sum'] = 1.0

-        beta1_final, beta2 = group["betas"]

-
+
+    @torch.no_grad()
+    def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float, beta1, num_sum, den_sum):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["orthogonal_gradient"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        ___, beta2 = group["betas"]
+
         if group.get('kourkoutas_beta', False):
-            # Call prepare_step() once at the beginning of the step for all params
-            self.kourkoutas_helper.maybe_prepare_step(current_step)
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)

-        beta1_warmup = group["beta1_warmup"]
         alpha_grad = group["alpha_grad"]

-        if beta1_warmup is not None:
-            step = state['step'] + 1
-            beta1 = linear_hl_warmup_scheduler(step, beta_end=beta1_final, beta_start=group['min_beta1'], warmup=beta1_warmup)
-        else:
-            beta1 = beta1_final
-
-        if group['use_bias_correction']:
-            state['num_sum'] = beta1 * state['num_sum'] + 1.0
-            if group.get('kourkoutas_beta', False):
-                state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
-            else:
-                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

         if state['factored']:
             d1, d2 = state['effective_shape']

@@ -233,12 +238,12 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
             del grad_reshaped

-            denom = vt.sqrt().add_(group['eps'] * math.sqrt(
+            denom = vt.sqrt().add_(group['eps'] * math.sqrt(den_sum))
             update.div_(denom)
             del denom

             if group['use_bias_correction']:
-                update = (update /
+                update = (update / num_sum) * math.sqrt(den_sum)

             update = update.view(p.shape).mul_(group['lr'])

@@ -259,12 +264,12 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):

             exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

-            denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(
+            denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(den_sum))
             update.div_(denom)
             del denom

             if group['use_bias_correction']:
-                update = (update /
+                update = (update / num_sum) * math.sqrt(den_sum)

             update.mul_(group['lr'])
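The `num_sum`/`den_sum` terms in the hunks above are scalar bias-correction accumulators that 2.dev3 moves from per-parameter state onto the optimizer itself. A small standalone sketch (not the package's code) of the recurrences used in `step()` and their closed forms when iterated from zero (the constructor seeds them with one application of the recurrence):

```python
# Sketch only: reproduces the recurrences visible in the diff and checks them
# against the usual Adam-style bias-correction terms.
beta1, beta2 = 0.99, 0.999
num_sum, den_sum = 0.0, 0.0

for t in range(1, 101):
    num_sum = beta1 * num_sum + 1.0              # momentum normalizer
    den_sum = beta2 * den_sum + (1.0 - beta2)    # second-moment normalizer

    assert abs(num_sum - (1 - beta1**t) / (1 - beta1)) < 1e-9
    assert abs(den_sum - (1 - beta2**t)) < 1e-9

# The update in the hunks above then divides by num_sum and rescales by
# sqrt(den_sum), i.e. roughly:
#   update = (m + alpha_grad * g) / (sqrt(v) + eps * sqrt(den_sum)) / num_sum * sqrt(den_sum)
```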
@@ -281,7 +286,36 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         p.data.add_(-update)
         del update

-
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if self.global_step is None and 'step' in self.state[p]:
+            # For backward compatibility
+            g_state = self.state[p]
+            self.global_step = g_state['step']
+            self.num_sum = group["betas"][0] * g_state['num_sum'] + 1.0
+            self.den_sum = group['betas'][1] * g_state['den_sum'] + (1.0 - group['betas'][1])
+
+        if group["beta1_warmup"] is not None:
+            step = self.global_step + 1
+            beta1 = linear_hl_warmup_scheduler(step, beta_end=group["betas"][0], beta_start=group['min_beta1'], warmup=group["beta1_warmup"])
+        else:
+            beta1 = group["betas"][0]
+
+        if group.get('kourkoutas_beta', False):
+            # Prepare Kourkoutas-β once per step using the global step counter.
+            self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+        if not group.get('compiled_optimizer', False):
+            self.__step_parameter(p, group, group['lr'], beta1, self.num_sum, self.den_sum)
+        else:
+            lr_tensor = torch.tensor(group['lr'], device=p.device)
+            num_sum_tesnor = torch.tensor(self.num_sum, device=p.device)
+            den_sum_tesnor = torch.tensor(self.den_sum, device=p.device)
+            self._compiled_step_parameter(p, group, lr_tensor, beta1, self.num_sum, self.den_sum)
+
+    def compile(self, *args, **kwargs):
+        self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

     @torch.no_grad()
     def step(self, closure=None):

@@ -294,5 +328,12 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         for group in self.param_groups:
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
+
+        g_group = self.param_groups[0]
+        if g_group['use_bias_correction']:
+            self.num_sum = g_group["betas"][0] * self.num_sum + 1.0
+            self.den_sum = g_group['betas'][1] * self.den_sum + (1.0 - g_group['betas'][1])
+
+        self.global_step += 1

-        return loss
+        return loss
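The `compile()` method above simply wraps the private per-parameter step with `torch.compile`. A generic, hedged sketch of that pattern outside the package (class and method names here are illustrative, not the library's API):

```python
import torch

class ToyOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, compiled_optimizer=False):
        super().__init__(params, dict(lr=lr))
        self._step_fn = self._step_parameter
        if compiled_optimizer:
            # A large cache limit avoids recompiles when many distinct parameter
            # shapes flow through the same compiled function.
            torch._dynamo.config.cache_size_limit = 8192
            self._step_fn = torch.compile(self._step_parameter, fullgraph=True)

    @torch.no_grad()
    def _step_parameter(self, p, lr):
        # Plain SGD body stands in for the real factored/AdEMAMix update.
        if p.grad is not None:
            p.add_(p.grad * -lr)

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                # Passing lr as a tensor (as 2.dev3 does) keeps its value from
                # being baked into the compiled graph as a constant.
                lr = torch.tensor(group["lr"], device=p.device)
                self._step_fn(p, lr)
        return loss
```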
adv_optm/optim/__init__.py
CHANGED

adv_optm/util/BF16_Stochastic_Rounding.py
CHANGED

@@ -44,4 +44,4 @@ def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
     result = other.clone() if other.dtype == torch.float32 else other.to(dtype=torch.float32)

     result.add_(input, alpha=alpha)
-    copy_stochastic_(input, result)
+    copy_stochastic_(input, result)
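The `copy_stochastic_(input, result)` call above writes an FP32 result back into a BF16 tensor with stochastic rounding. The package's implementation is not shown in this diff; a common way to implement it (an assumption, not the library's code) is to add random bits below the BF16 mantissa before truncating:

```python
import torch

def copy_stochastic_sketch(target: torch.Tensor, source: torch.Tensor) -> None:
    """Round `source` (fp32) into `target` (bf16) stochastically, in place. Sketch only."""
    # Reinterpret the fp32 bits as int32, add uniform noise in the 16 low bits
    # that bf16 truncation would discard, then truncate by zeroing those bits.
    src_bits = source.view(torch.int32)
    noise = torch.randint_like(src_bits, 0, 1 << 16)
    rounded = (src_bits + noise) & -65536  # keep only the top 16 bits
    target.copy_(rounded.view(torch.float32).to(torch.bfloat16))

# usage: accumulate an update in fp32, then round back into the bf16 parameter
p = torch.randn(4, dtype=torch.bfloat16)
acc = p.to(torch.float32).add_(torch.randn(4), alpha=0.1)
copy_stochastic_sketch(p, acc)
```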
adv_optm/util/Effective_Shape.py
CHANGED
adv_optm/util/Kourkoutas.py
CHANGED
@@ -1,6 +1,5 @@
 import torch
 from torch.optim import Optimizer
-from typing import Callable

 class KourkoutasHelper:
     """

@@ -58,7 +57,7 @@ class KourkoutasHelper:
         Calculates dynamic beta2 for all layers using the completed scalar accumulators
         from the PREVIOUS step. Should be called once at the start of an optimizer step.
         """
-
+
         beta2_log = []
         # These are just for the sample log, initialize them
         sun, pooled_grad_norm, r_ema_tensor = (torch.tensor(0.0),)*3

@@ -69,7 +68,7 @@ class KourkoutasHelper:
         master_defaults = self.optimizer.defaults

         for layer_key, info in self.layer_info.items():
-
+            group = info['group_ref']

             if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
                 continue

@@ -81,7 +80,7 @@ class KourkoutasHelper:
                 self.layer_state[layer_key] = {
                     'sum_sq_accumulator': torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
                 }
-
+
             if 'kourkoutas_r_ema' not in param_state:
                 param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)

@@ -96,14 +95,14 @@ class KourkoutasHelper:

             r_ema_tensor = param_state['kourkoutas_r_ema']
             accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
-
+
             pooled_grad_norm = torch.sqrt(accumulator)
-
+
             # Update the persistent EMA tensor in-place.
             r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
-
+
             sun = torch.tensor(0.0, device=r_ema_tensor.device)  # Default sun to 0 for warmup
-
+
             if current_step < k_warmup_steps:
                 beta2 = beta2_max
             else:

@@ -113,7 +112,7 @@ class KourkoutasHelper:

             # Store the final calculated beta2 in the helper's transient state for this step.
             self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
-
+
             # Reset the accumulator for the next optimizer step.
             accumulator.zero_()

@@ -149,11 +148,10 @@ class KourkoutasHelper:
         # Accumulate for the *next* step's prepare_step call
         self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()

-    def get_beta2(self, p: torch.Tensor, group: dict
+    def get_beta2(self, p: torch.Tensor, group: dict) -> float:
         """
         Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
         """
         layer_key = self.optimizer.layer_key_fn(p)
         # The default is the max value, which is correct for unmapped params or edge cases
-
-        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
+        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
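Putting the three helper entry points visible above together, the per-step flow is: each parameter's squared-gradient norm is pooled per layer via `accumulate_gradient_sq_norm`, `maybe_prepare_step` turns the previous step's pooled norms into one dynamic β₂ per layer, and `get_beta2` hands that value back to the update. A schematic, standalone sketch of that flow; the sunspike formula here is an assumption in the spirit of the Kourkoutas-β paper, not code from this diff:

```python
import torch

class KourkoutasSketch:
    """Toy per-layer dynamic beta2, mirroring the accumulate -> prepare -> get cycle."""

    def __init__(self, beta2_max=0.999, beta2_min=0.88, ema_alpha=0.95, tiny_spike=1e-9):
        self.beta2_max, self.beta2_min = beta2_max, beta2_min
        self.ema_alpha, self.tiny_spike = ema_alpha, tiny_spike
        self.acc = {}      # layer_key -> sum of squared grad norms (this step)
        self.r_ema = {}    # layer_key -> EMA of the pooled gradient norm
        self.beta2 = {}    # layer_key -> dynamic beta2 for the current step

    def accumulate(self, layer_key, grad):
        self.acc[layer_key] = self.acc.get(layer_key, 0.0) + grad.pow(2).sum().item()

    def prepare_step(self):
        for key, sq in self.acc.items():
            pooled = sq ** 0.5
            ema = self.ema_alpha * self.r_ema.get(key, 0.0) + (1 - self.ema_alpha) * pooled
            self.r_ema[key] = ema
            # Assumed sunspike ratio: a burst relative to the running norm pushes
            # beta2 down toward beta2_min; calm phases keep it near beta2_max.
            sun = pooled / (ema + self.tiny_spike)
            sun = sun / (1.0 + sun)  # bounded in [0, 1)
            self.beta2[key] = self.beta2_max - sun * (self.beta2_max - self.beta2_min)
            self.acc[key] = 0.0

    def get_beta2(self, layer_key):
        return self.beta2.get(layer_key, self.beta2_max)
```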
adv_optm/util/NNMF.py
CHANGED
@@ -9,10 +9,15 @@ def _nnmf(matrix: torch.Tensor, out: tuple):
     shape = matrix.shape
     torch.sum(matrix, dim=1, out=out[0])
     torch.sum(matrix, dim=0, out=out[1])
+
+    # Add a small epsilon for numerical stability and to remove
+    # data-dependent branching, making it compatible with torch.dynamo.
+    epsilon = 1e-12
+
     # Normalize one of the factors for stability
     if shape[0] < shape[1]:
         scale = out[0].sum()
-
+        out[0].div_(scale + epsilon)
     else:
         scale = out[1].sum()
-
+        out[1].div_(scale + epsilon)
adv_optm/util/One_Bit_Boolean.py
CHANGED
@@ -19,4 +19,4 @@ def _unpack_bools(packed_tensor: torch.Tensor, original_m: int) -> torch.Tensor:
     shifter = (2**torch.arange(8, device=packed_tensor.device, dtype=torch.uint8)).view(1, 1, 8)
     unpacked_padded = (packed_tensor.unsqueeze(2) & shifter) != 0
     unpacked = unpacked_padded.view(packed_tensor.shape[0], -1)[:, :original_m]
-    return unpacked
+    return unpacked
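For context, `_unpack_bools` is the inverse of packing eight booleans per byte, which is how the first-moment sign is stored at 1 bit per element. A hedged round-trip sketch; only the unpacking logic above is visible in the diff, the packing routine here is assumed:

```python
import torch

def pack_bools(bools: torch.Tensor) -> torch.Tensor:
    """Pack an (n, m) bool tensor into (n, ceil(m/8)) uint8, little-endian within each byte."""
    n, m = bools.shape
    padded_m = (m + 7) // 8 * 8
    padded = torch.zeros(n, padded_m, dtype=torch.bool, device=bools.device)
    padded[:, :m] = bools
    shifter = (2 ** torch.arange(8, device=bools.device, dtype=torch.uint8)).view(1, 1, 8)
    return (padded.view(n, -1, 8).to(torch.uint8) * shifter).sum(dim=2).to(torch.uint8)

def unpack_bools(packed: torch.Tensor, original_m: int) -> torch.Tensor:
    shifter = (2 ** torch.arange(8, device=packed.device, dtype=torch.uint8)).view(1, 1, 8)
    unpacked_padded = (packed.unsqueeze(2) & shifter) != 0
    return unpacked_padded.view(packed.shape[0], -1)[:, :original_m]

signs = torch.rand(3, 10) > 0.5
assert torch.equal(unpack_bools(pack_bools(signs), 10), signs)  # lossless round trip
```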
adv_optm/util/OrthoGrad.py
CHANGED
@@ -2,15 +2,16 @@ import torch

 def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
     """Projects the gradient `grad` to be orthogonal to the parameter `p`."""
-    if grad.is_sparse:
+    if grad.is_sparse:
+        raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
     original_shape = grad.shape
     original_dtype = grad.dtype
     w = p.view(-1).float()
     g = grad.view(-1).float()
     w_norm_sq = torch.dot(w, w).add_(1e-30)
     proj = torch.dot(w, g) / w_norm_sq
-    g_orth = g.sub(w
+    g_orth = g.sub(w * proj)
     g_norm = g.norm(2)
     g_orth_norm = g_orth.norm(2).add_(1e-30)
     g_orth_scaled = g_orth * (g_norm / g_orth_norm)
-    return g_orth_scaled.view(original_shape).to(original_dtype)
+    return g_orth_scaled.view(original_shape).to(original_dtype)
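The function above removes the component of the gradient parallel to the weights and rescales the remainder back to the original gradient norm. A quick standalone check of those two properties, using the same formula rather than importing the package:

```python
import torch

def orthogonalize(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    w, g = p.view(-1).float(), grad.view(-1).float()
    proj = torch.dot(w, g) / (torch.dot(w, w) + 1e-30)
    g_orth = g - w * proj                                  # remove the parallel component
    g_orth = g_orth * (g.norm(2) / (g_orth.norm(2) + 1e-30))  # restore the original norm
    return g_orth.view(grad.shape).to(grad.dtype)

p = torch.randn(32, 16)
grad = torch.randn(32, 16)
out = orthogonalize(p, grad)

print(torch.dot(out.view(-1), p.view(-1)))    # ~0: orthogonal to the weights
print(out.norm().item(), grad.norm().item())  # ~equal: norm preserved
```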
adv_optm/util/__init__.py
CHANGED
{adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.2.dev19
+Version: 2.dev3
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

@@ -52,7 +52,7 @@
 ### **Memory-Efficient Optimization (SMMF-inspired)**
 - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
 - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
-- **Innovation**:
+- **Innovation**:
   - First moment split into **1-bit sign + absolute value**
   - Final storage: **four factored vectors + one 1-bit sign state**
   - Preserves Adam-like update quality with drastically reduced memory

@@ -110,7 +110,7 @@
 ## 🛠️ Comprehensive Feature Guide

-### A. Universal Safe Features
+### A. Universal Safe Features
 *These features work with all optimizers and are generally safe to enable.*

 | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |

@@ -165,7 +165,7 @@
 | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
 | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |

-> ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
+> ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
 > For `Prodigy_Adv`, set `initial_d` to:
 > - **LoRA**: `1e-8`
 > - **Full FT**: `1e-10`

@@ -175,10 +175,10 @@
 #### Performance Validation

-**Small Batch Training (SDXL, BS=2, 1.8K steps)**
+**Small Batch Training (SDXL, BS=2, 1.8K steps)**
 

-- **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
+- **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
 - **🔵 Prodigy_Adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR = 5.8e-6

 **Results**:

@@ -202,8 +202,8 @@
 Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:

-- **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
-- **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
+- **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
+- **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing

 This is especially effective for **noisy training, small batch sizes, and high learning rates**, where gradient norms shift abruptly due to noise or aggressive LR schedules.

@@ -220,17 +220,17 @@
 #### 📊 Performance Validation

-**ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
+**ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
 <img width="1460" height="382" alt="image" src="https://github.com/user-attachments/assets/007f278a-fbac-4f3d-9cc7-274c3b959cdd" />

-- 🟣 Fixed `beta2=0.999`
-- 🟠 Auto K-beta
+- 🟣 Fixed `beta2=0.999`
+- 🟠 Auto K-beta

-**Observations:**
+**Observations:**
 - K-beta is clearly better and more robust/stable for high LRs.

-> 📚 **Reference**:
-> - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
+> 📚 **Reference**:
+> - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
 > - Code: [kbeta](https://github.com/sck-at-ucy/kbeta)

 ---

@@ -258,7 +258,7 @@ settings:
 - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.

-> ✅ **Why it works**:
+> ✅ **Why it works**:
 > - `Kourkoutas-β` handles beta2 values
 > - `Simplified_AdEMAMix` ensures responsiveness in small-batch noise
 > - `OrthoGrad` prevents overfitting without weight decay

@@ -267,9 +267,9 @@
 ## 📚 References

-1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
-2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
-3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
-4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
-5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
+1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
+2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
+4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
+5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
 6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
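To make the "four factored vectors + one 1-bit sign state" claim in the METADATA excerpt concrete, here is a back-of-the-envelope memory comparison for a single square weight matrix, based on the state tensors created in `__init_state` in the Simplified_AdEMAMix diff above (the sizes are illustrative):

```python
# Rough per-parameter optimizer-state memory for a (d1, d2) weight.
d1, d2 = 4096, 4096

adam_like = 2 * d1 * d2 * 4  # exp_avg + exp_avg_sq, fp32 bytes
factored = (2 * d1 + 2 * d2) * 4 + d1 * ((d2 + 7) // 8)  # 4 fp32 vectors + packed 1-bit signs

print(f"unfactored: {adam_like / 2**20:.1f} MiB")  # ~128.0 MiB
print(f"factored:   {factored / 2**20:.1f} MiB")   # ~2.1 MiB
```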
adv_optm-2.dev3.dist-info/RECORD
ADDED

@@ -0,0 +1,23 @@
+adv_optm/__init__.py,sha256=5x-lSBvBDU_bD2E4mS2a4b2ElrfIgg9kHmQdBUwghbk,379
+adv_optm/optim/AdaMuon_adv.py,sha256=zjZHFS7ng5KwemQzePjFiGtNZlcgbzmmnqF6A80h_Tg,34652
+adv_optm/optim/AdamW_adv.py,sha256=VC6NpR9lDaRS6CIDIWdEXE_-2Z1opa0lXCxYZy8FEEI,18242
+adv_optm/optim/Adopt_adv.py,sha256=FRYaqCyxzxUzt1geQj00WCWX0_71_8-cQyVNXaZeVBU,21898
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=ZRld1lt2wQtMBideFA-FStStvfV_oEMCrswww5rYAso,14103
+adv_optm/optim/Lion_adv.py,sha256=GNkuFIwIjKwQElXjVbwjfwhe4lv4D_Qb0gbOjHl151g,8452
+adv_optm/optim/Muon_adv.py,sha256=d91wvmKKt_3IPqsqK1ZZ5cY71kuXyzy04IU3krn2NQ8,33316
+adv_optm/optim/Prodigy_adv.py,sha256=jY7zEWJ49ICqBERFf1fue126sZg0-o1Mu7M9pa_66Gs,26529
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=4JGaX6DDm0zdY8NxXzRIGm4pqb33on8Xw-uImxO3WNE,14399
+adv_optm/optim/__init__.py,sha256=F4f-D8QGIByXHAZAu0keJf4foA22NpK-L9QgywVxAm8,491
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=b8bE7xGtJxZnQYCqdPKtYb8xYGrDftO6jCLLKLa9Ut8,1550
+adv_optm/util/Effective_Shape.py,sha256=h9pF4HaCkjDyo2dxlUpM66oD6FtclQnb7yPPfvReHyI,320
+adv_optm/util/Kourkoutas.py,sha256=8Lik30MACDwM77aNWmMecmPS9g31fT4jE6fuIG4QMTk,7366
+adv_optm/util/NNMF.py,sha256=hrvNGERj8evhPIWnWzsKdm5DwIZblTB4pkhc9xWytSY,794
+adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBmY,3341
+adv_optm/util/One_Bit_Boolean.py,sha256=tE8lSnbKR3oO-EtM0Kzvf0E4hmuBvhmtFR_75su-DNI,1070
+adv_optm/util/OrthoGrad.py,sha256=doP667YpdiEdP3-cpyWiRNkAdkT-nzs45VSafOCRDHw,713
+adv_optm/util/__init__.py,sha256=cA5zt5dvznkOw2lqbaGvFjslznB1UEFYYZMMFsXrWBg,437
+adv_optm-2.dev3.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-2.dev3.dist-info/METADATA,sha256=ttkFBXVB97D9Fi3_AgO2bA9b-x-9sm0YSKujVtSLuBU,13983
+adv_optm-2.dev3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-2.dev3.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-2.dev3.dist-info/RECORD,,
adv_optm-1.2.dev19.dist-info/RECORD
DELETED

@@ -1,23 +0,0 @@
-adv_optm/__init__.py,sha256=1AKxG--scx5Bl9G08tQcnfzAMaQVSgmW99uy3v2QWMw,380
-adv_optm/optim/AdaMuon_adv.py,sha256=7Had92OcsCiN1E9UJRyrpPV7VzHqmIvS-qM6OEcc24I,34671
-adv_optm/optim/AdamW_adv.py,sha256=jgMuRAfsnUh_2wUEZgYpJX5uwoT_kQjtMs2Xn2vJ3x0,17480
-adv_optm/optim/Adopt_adv.py,sha256=kbAeBG4bXWBvgj_qrE9W67J6c0swpEi4Erj2rfYrMXE,21252
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
-adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=tZY8K3pNBCGk1V09GbK05lJooFw92NfkF7_T548up3Q,33171
-adv_optm/optim/Prodigy_adv.py,sha256=k7f2J_RQpnrUXjwER_XOokISlQWpTSwGG-OL-bjMfBk,26061
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
-adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=BnBj4WlTOJXOW0dv_vBBE27HxDTbI_1qDIWW2J7Bxdo,7644
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBmY,3341
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.dev19.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.2.dev19.dist-info/METADATA,sha256=pQm5WuMKvf5Xse10viziVK9ry1UufcYRDwOd55jad8Y,14023
-adv_optm-1.2.dev19.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.2.dev19.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.2.dev19.dist-info/RECORD,,
{adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/WHEEL
File without changes

{adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/top_level.txt
File without changes