adv-optm 1.2.dev13__py3-none-any.whl → 2.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic. Click here for more details.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +85 -64
- adv_optm/optim/Adopt_adv.py +114 -69
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +37 -42
- adv_optm/optim/Prodigy_adv.py +105 -85
- adv_optm/optim/Simplified_AdEMAMix.py +92 -51
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +11 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/Newton_Schulz.py +1 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/METADATA +20 -20
- adv_optm-2.dev1.dist-info/RECORD +23 -0
- adv_optm-1.2.dev13.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/top_level.txt +0 -0
adv_optm/optim/Adopt_adv.py
CHANGED
|
@@ -56,13 +56,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
56
56
|
before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
|
|
57
57
|
A higher value increases the stabilizing influence of the slow
|
|
58
58
|
momentum. (default: 5.0)
|
|
59
|
-
t_alpha (Optional[int]): The number of steps for a linear warmup of the
|
|
60
|
-
`alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
|
|
61
|
-
highly recommended to prevent instability at the beginning of training,
|
|
62
|
-
as it gradually introduces the stabilizing slow momentum term. During
|
|
63
|
-
the warmup, `alpha` ramps from 0 to its target value. If `None`,
|
|
64
|
-
the scheduler is disabled and the full `alpha` value is used from
|
|
65
|
-
the start. (default: None)
|
|
66
59
|
Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
|
|
67
60
|
This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
|
|
68
61
|
more responsive, especially for small batch sizes. Enabling this will
|
|
@@ -90,7 +83,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
90
83
|
k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
|
|
91
84
|
logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
|
|
92
85
|
every logging steps. Useful for debugging and tuning. Set to 0 to disable
|
|
93
|
-
logging (default: 0).
|
|
86
|
+
logging (default: 0).
|
|
94
87
|
layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
|
|
95
88
|
and returns a unique, hashable key representing its "layer" or "bucket".
|
|
96
89
|
If `None`, parameters are bucketed by their memory ID (tensor-wise).
|
|
@@ -107,7 +100,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
107
100
|
eps: float = 1e-6,
|
|
108
101
|
weight_decay: float = 0.0,
|
|
109
102
|
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
|
110
|
-
vector_reshape: bool =
|
|
103
|
+
vector_reshape: bool = False,
|
|
111
104
|
stochastic_rounding: bool = True,
|
|
112
105
|
use_atan2: bool = False,
|
|
113
106
|
cautious_mask: bool = False,
|
|
@@ -116,7 +109,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
116
109
|
use_AdEMAMix: bool = False,
|
|
117
110
|
beta3_ema: float = 0.9999,
|
|
118
111
|
alpha: float = 5.0,
|
|
119
|
-
t_alpha: int | None = None,
|
|
120
112
|
Simplified_AdEMAMix: bool = False,
|
|
121
113
|
alpha_grad: float = 100.0,
|
|
122
114
|
kourkoutas_beta: bool = False,
|
|
@@ -127,6 +119,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
127
119
|
k_logging: int = 0,
|
|
128
120
|
layer_key_fn: Optional[Callable] = None,
|
|
129
121
|
nnmf_factor: bool = False,
|
|
122
|
+
# Compiled
|
|
123
|
+
compiled_optimizer: bool = False,
|
|
130
124
|
):
|
|
131
125
|
if not (lr >= 0.0):
|
|
132
126
|
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
@@ -141,7 +135,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
141
135
|
cautious_mask = False
|
|
142
136
|
if betas[0] == 0.0 and Simplified_AdEMAMix:
|
|
143
137
|
raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
|
|
144
|
-
if kourkoutas_beta and not (betas[1] > beta2_min):
|
|
138
|
+
if kourkoutas_beta and not (betas[1] > beta2_min):
|
|
139
|
+
raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
|
|
145
140
|
if use_AdEMAMix and Simplified_AdEMAMix:
|
|
146
141
|
print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
|
|
147
142
|
if grams_moment and Simplified_AdEMAMix:
|
|
@@ -152,9 +147,10 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
152
147
|
defaults = {
|
|
153
148
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
154
149
|
"vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
|
|
155
|
-
"
|
|
150
|
+
"alpha_grad": alpha_grad,
|
|
156
151
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
157
152
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
153
|
+
"compiled_optimizer": compiled_optimizer,
|
|
158
154
|
}
|
|
159
155
|
self.clip_lambda = clip_lambda
|
|
160
156
|
self.stochastic_rounding = stochastic_rounding
|
|
@@ -169,9 +165,17 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
169
165
|
self.layer_key_fn = layer_key_fn
|
|
170
166
|
super().__init__(params, defaults)
|
|
171
167
|
|
|
168
|
+
self.init_step()
|
|
169
|
+
|
|
172
170
|
if self.kourkoutas_beta:
|
|
173
171
|
self.kourkoutas_helper = KourkoutasHelper(self)
|
|
174
172
|
|
|
173
|
+
self.global_step = 0
|
|
174
|
+
|
|
175
|
+
if compiled_optimizer:
|
|
176
|
+
torch._dynamo.config.cache_size_limit = 8192
|
|
177
|
+
self.compile(fullgraph=True)
|
|
178
|
+
|
|
175
179
|
@property
|
|
176
180
|
def supports_fused_back_pass(self): return True
|
|
177
181
|
@property
|
|
@@ -179,29 +183,22 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
179
183
|
@property
|
|
180
184
|
def supports_flat_params(self): return False
|
|
181
185
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
+
def init_step(self):
|
|
187
|
+
for group in self.param_groups:
|
|
188
|
+
for p in group['params']:
|
|
189
|
+
self.__init_state(p, group)
|
|
186
190
|
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
grad = grad.float()
|
|
190
|
-
if self.orthogonal_gradient:
|
|
191
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
191
|
+
@torch.no_grad()
|
|
192
|
+
def __init_state(self, p, group):
|
|
192
193
|
state = self.state[p]
|
|
193
194
|
|
|
194
|
-
|
|
195
|
-
if 'step' not in state:
|
|
196
|
-
state['step'] = 0
|
|
195
|
+
if len(state) == 0:
|
|
197
196
|
|
|
198
|
-
|
|
197
|
+
state['factored'] = (
|
|
199
198
|
self.factored and
|
|
200
199
|
not (len(p.shape) == 1 and not group['vector_reshape'])
|
|
201
200
|
)
|
|
202
201
|
|
|
203
|
-
state['factored'] = should_factor
|
|
204
|
-
|
|
205
202
|
dtype = torch.float32 if self.factored else p.dtype
|
|
206
203
|
|
|
207
204
|
if state['factored']:
|
|
@@ -210,55 +207,75 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
210
207
|
|
|
211
208
|
# m_0 = 0
|
|
212
209
|
if group['betas'][0] > 0:
|
|
213
|
-
state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
210
|
+
state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
214
211
|
state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
215
212
|
if not self.grams_moment:
|
|
216
213
|
packed_d2 = (d2 + 7) // 8
|
|
217
214
|
state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
|
|
218
215
|
if self.use_AdEMAMix:
|
|
219
|
-
state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
216
|
+
state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
220
217
|
state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
221
218
|
packed_d2 = (d2 + 7) // 8
|
|
222
219
|
state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
|
|
223
|
-
|
|
224
|
-
vt_init = grad.view(d1, d2).square_()
|
|
225
|
-
# Allocate NMF factors for v
|
|
226
|
-
state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
227
|
-
state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
228
|
-
# Initialize v_0 using NMF
|
|
229
|
-
_nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
|
|
220
|
+
|
|
230
221
|
else: # Fallback for non-factored tensors
|
|
231
222
|
if group['betas'][0] > 0:
|
|
232
223
|
state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
|
|
233
224
|
if self.use_AdEMAMix:
|
|
234
225
|
state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
|
|
235
|
-
|
|
226
|
+
|
|
227
|
+
@torch.no_grad()
|
|
228
|
+
def __init_step(self, p, group):
|
|
229
|
+
if p.grad is None:
|
|
230
|
+
return
|
|
231
|
+
|
|
232
|
+
state = self.state[p]
|
|
233
|
+
|
|
234
|
+
if 'exp_avg_sq' in state or 'mu_v_nmf' in state:
|
|
235
|
+
return
|
|
236
|
+
|
|
237
|
+
grad = p.grad
|
|
238
|
+
dtype = torch.float32 if self.factored else p.dtype
|
|
239
|
+
|
|
240
|
+
if state['factored']:
|
|
241
|
+
d1, d2 = state['effective_shape']
|
|
242
|
+
# v_0 = g_0^2 (SMMF_ADOPT NMF storage)
|
|
243
|
+
vt_init = grad.view(d1, d2).square_()
|
|
244
|
+
# Allocate NMF factors for v
|
|
245
|
+
state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
246
|
+
state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
247
|
+
# Initialize v_0 using NMF
|
|
248
|
+
_nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
|
|
249
|
+
del vt_init
|
|
250
|
+
else:
|
|
251
|
+
state['exp_avg_sq'] = grad.square() # v_0
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@torch.no_grad()
|
|
255
|
+
def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
|
|
256
|
+
if p.grad is None:
|
|
257
|
+
return
|
|
258
|
+
|
|
259
|
+
grad = p.grad
|
|
260
|
+
if self.factored and grad.dtype != torch.float32:
|
|
261
|
+
grad = grad.float()
|
|
262
|
+
if self.orthogonal_gradient:
|
|
263
|
+
grad = _orthogonalize_gradient(p, grad)
|
|
264
|
+
state = self.state[p]
|
|
265
|
+
|
|
236
266
|
|
|
237
267
|
beta1, beta2 = group['betas']
|
|
238
268
|
|
|
239
|
-
current_step = state['step']
|
|
240
269
|
if group.get('kourkoutas_beta', False):
|
|
241
|
-
# Call prepare_step() once at the beginning of the step for all params
|
|
242
|
-
self.kourkoutas_helper.maybe_prepare_step(current_step)
|
|
243
270
|
# Accumulate current grad's norm for the *next* step
|
|
244
271
|
self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
|
|
245
272
|
# Get the dynamic beta2 calculated in prepare_step()
|
|
246
|
-
beta2 = self.kourkoutas_helper.get_beta2(p, group
|
|
247
|
-
|
|
248
|
-
# The first step is for initialization only (skip when use_atan2 as it's scale invariant).
|
|
249
|
-
if state['step'] == 0 and not self.use_atan2:
|
|
250
|
-
state['step'] += 1
|
|
251
|
-
return
|
|
273
|
+
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
252
274
|
|
|
253
275
|
if self.use_AdEMAMix:
|
|
254
276
|
beta3_ema = group['beta3_ema']
|
|
255
277
|
alpha = group['alpha']
|
|
256
|
-
|
|
257
|
-
# Use step+1 for 1-based step count in scheduler
|
|
258
|
-
alpha_step = state['step'] + 1
|
|
259
|
-
alpha_t = alpha
|
|
260
|
-
if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
|
|
261
|
-
alpha_t = min(alpha_step * alpha / t_alpha, alpha)
|
|
278
|
+
|
|
262
279
|
if self.Simplified_AdEMAMix:
|
|
263
280
|
alpha_grad = group["alpha_grad"]
|
|
264
281
|
|
|
@@ -296,7 +313,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
296
313
|
else:
|
|
297
314
|
normalized_grad = grad_reshaped / denom.add_(group['eps'])
|
|
298
315
|
if self.clip_lambda is not None:
|
|
299
|
-
clip_val = self.clip_lambda(
|
|
316
|
+
clip_val = self.clip_lambda(self.global_step)
|
|
300
317
|
normalized_grad.clamp_(-clip_val, clip_val)
|
|
301
318
|
del denom
|
|
302
319
|
|
|
@@ -307,7 +324,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
307
324
|
else:
|
|
308
325
|
mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
|
|
309
326
|
if self.grams_moment:
|
|
310
|
-
mt = grad_reshaped.sign()
|
|
327
|
+
mt = grad_reshaped.sign().mul_(mt.abs())
|
|
311
328
|
elif self.cautious_mask:
|
|
312
329
|
mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
|
|
313
330
|
mask.div_(mask.mean().clamp_(min=1e-3))
|
|
@@ -317,9 +334,9 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
317
334
|
if self.use_AdEMAMix:
|
|
318
335
|
mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
|
|
319
336
|
if beta1 > 0:
|
|
320
|
-
update = torch.add(mt, mt_slow, alpha=
|
|
337
|
+
update = torch.add(mt, mt_slow, alpha=alpha)
|
|
321
338
|
else:
|
|
322
|
-
update = torch.add(normalized_grad, mt_slow, alpha=
|
|
339
|
+
update = torch.add(normalized_grad, mt_slow, alpha=alpha)
|
|
323
340
|
elif self.Simplified_AdEMAMix:
|
|
324
341
|
update = torch.add(mt, normalized_grad, alpha=alpha_grad)
|
|
325
342
|
else:
|
|
@@ -328,9 +345,9 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
328
345
|
update = update.view(p.shape)
|
|
329
346
|
|
|
330
347
|
if self.use_atan2:
|
|
331
|
-
update.mul_(
|
|
348
|
+
update.mul_(lr * 1.2732395447351628)
|
|
332
349
|
else:
|
|
333
|
-
update.mul_(
|
|
350
|
+
update.mul_(lr)
|
|
334
351
|
|
|
335
352
|
# Update second moment v_t for the *next* step using raw g_t
|
|
336
353
|
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
@@ -353,7 +370,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
353
370
|
del vt
|
|
354
371
|
|
|
355
372
|
else: # Standard ADOPT logic for non-factored tensors
|
|
356
|
-
v = state['exp_avg_sq'] # v_{t-1}
|
|
373
|
+
v = state['exp_avg_sq'] # v_{t-1}
|
|
357
374
|
|
|
358
375
|
# ADOPT Step A: Decorrelate g_t using v_{t-1}
|
|
359
376
|
denom = v.sqrt()
|
|
@@ -363,7 +380,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
363
380
|
else:
|
|
364
381
|
normalized_grad = grad / denom.add_(group['eps'])
|
|
365
382
|
if self.clip_lambda is not None:
|
|
366
|
-
clip_val = self.clip_lambda(
|
|
383
|
+
clip_val = self.clip_lambda(self.global_step)
|
|
367
384
|
normalized_grad.clamp_(-clip_val, clip_val)
|
|
368
385
|
del denom
|
|
369
386
|
|
|
@@ -376,7 +393,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
376
393
|
m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
|
|
377
394
|
|
|
378
395
|
if self.grams_moment:
|
|
379
|
-
m = grad.sign()
|
|
396
|
+
m = grad.sign().mul_(m.abs())
|
|
380
397
|
elif self.cautious_mask:
|
|
381
398
|
mask = (m * grad > 0).to(grad.dtype)
|
|
382
399
|
mask.div_(mask.mean().clamp_(min=1e-3))
|
|
@@ -387,18 +404,18 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
387
404
|
m_slow = state['exp_avg_slow']
|
|
388
405
|
m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
|
|
389
406
|
if beta1 > 0:
|
|
390
|
-
update = torch.add(m, m_slow, alpha=
|
|
407
|
+
update = torch.add(m, m_slow, alpha=alpha)
|
|
391
408
|
else:
|
|
392
|
-
update = torch.add(normalized_grad, m_slow, alpha=
|
|
409
|
+
update = torch.add(normalized_grad, m_slow, alpha=alpha)
|
|
393
410
|
elif self.Simplified_AdEMAMix:
|
|
394
411
|
update = torch.add(m, normalized_grad, alpha=alpha_grad)
|
|
395
412
|
else:
|
|
396
413
|
update = m.clone() if beta1 > 0 else normalized_grad
|
|
397
414
|
|
|
398
415
|
if self.use_atan2:
|
|
399
|
-
update.mul_(
|
|
416
|
+
update.mul_(lr * 1.2732395447351628)
|
|
400
417
|
else:
|
|
401
|
-
update.mul_(
|
|
418
|
+
update.mul_(lr)
|
|
402
419
|
|
|
403
420
|
# Update second moment v_t for the next step using raw g_t
|
|
404
421
|
v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
|
@@ -406,9 +423,9 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
406
423
|
# Parameter Update
|
|
407
424
|
if group["weight_decay"] != 0:
|
|
408
425
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
409
|
-
add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] *
|
|
426
|
+
add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
|
|
410
427
|
else:
|
|
411
|
-
p.data.add_(p.data, alpha=-group["weight_decay"] *
|
|
428
|
+
p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
|
|
412
429
|
|
|
413
430
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
414
431
|
add_stochastic_(p.data, -update)
|
|
@@ -416,7 +433,33 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
416
433
|
p.data.add_(-update)
|
|
417
434
|
del update
|
|
418
435
|
|
|
419
|
-
|
|
436
|
+
@torch.no_grad()
|
|
437
|
+
def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
|
|
438
|
+
if self.global_step is None and 'step' in self.state[p]:
|
|
439
|
+
# For backward compatibility
|
|
440
|
+
self.global_step = self.state[p]['step']
|
|
441
|
+
|
|
442
|
+
if self.global_step == 0:
|
|
443
|
+
self.__init_step(p, group)
|
|
444
|
+
|
|
445
|
+
# The first step is for initialization only (skip when use_atan2 as it's scale invariant).
|
|
446
|
+
if self.global_step == 0 and not self.use_atan2:
|
|
447
|
+
self.global_step += 1
|
|
448
|
+
return
|
|
449
|
+
|
|
450
|
+
if group.get('kourkoutas_beta', False):
|
|
451
|
+
# Prepare Kourkoutas-β once per step using the global step counter.
|
|
452
|
+
self.kourkoutas_helper.maybe_prepare_step(self.global_step)
|
|
453
|
+
|
|
454
|
+
if not group.get('compiled_optimizer', False):
|
|
455
|
+
self.__step_parameter(p, group, group['lr'])
|
|
456
|
+
else:
|
|
457
|
+
lr_tensor = torch.tensor(group['lr'], device=p.device)
|
|
458
|
+
self._compiled_step_parameter(p, group, lr_tensor)
|
|
459
|
+
|
|
460
|
+
def compile(self, *args, **kwargs):
|
|
461
|
+
self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
|
|
462
|
+
|
|
420
463
|
|
|
421
464
|
@torch.no_grad()
|
|
422
465
|
def step(self, closure=None):
|
|
@@ -430,4 +473,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
430
473
|
for i, p in enumerate(group['params']):
|
|
431
474
|
self.step_parameter(p, group, i)
|
|
432
475
|
|
|
433
|
-
|
|
476
|
+
self.global_step += 1
|
|
477
|
+
|
|
478
|
+
return loss
|