adv-optm 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +2 -2
- adv_optm/optim/Adopt_adv.py +1 -1
- adv_optm/optim/Lion_Prodigy_adv.py +62 -36
- adv_optm/optim/Prodigy_adv.py +2 -2
- adv_optm/optim/Simplified_AdEMAMix.py +2 -2
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.3.dist-info}/METADATA +1 -1
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.3.dist-info}/RECORD +11 -11
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.3.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.3.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.3.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -209,7 +209,7 @@ class AdamW_adv(torch.optim.Optimizer):
             beta1, beta2 = group['betas']
 
             current_step = state['step']
-            if group['kourkoutas_beta']:
+            if group.get('kourkoutas_beta', False):
                 # Call prepare_step() once at the beginning of the step for all params
                 self.kourkoutas_helper.maybe_prepare_step(current_step)
                 # Accumulate current grad's norm for the *next* step

@@ -220,7 +220,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 step = state['step'] + 1
                 if group['use_bias_correction']:
                     bias_correction1 = 1.0 - beta1 ** step
-                    if group['kourkoutas_beta']:
+                    if group.get('kourkoutas_beta', False):
                         bias_correction2 = 1.0 - group['betas'][1] ** step
                         # Use beta2_max for bias correction
                     else:
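Every optimizer touched in this release applies the same one-line fix: the `kourkoutas_beta` flag is now read with `dict.get` and a `False` default, so parameter groups that predate the option (for example, ones restored from an older checkpoint) or that simply omit the key no longer raise a `KeyError`. A minimal sketch of the difference, using a hypothetical parameter group (the removed line is truncated in this view, so the old access pattern is inferred):

# Hypothetical param group that does not carry the kourkoutas_beta key.
group = {'lr': 1e-3, 'betas': (0.9, 0.999)}

# Direct indexing, as in 1.1.1, would raise KeyError('kourkoutas_beta'):
#     if group['kourkoutas_beta']: ...

# The 1.1.3 pattern falls back to False and takes the standard beta2 path:
if group.get('kourkoutas_beta', False):
    print("Kourkoutas-beta branch")
else:
    print("standard beta2 branch")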
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -240,7 +240,7 @@ class Adopt_adv(torch.optim.Optimizer):
             beta1, beta2 = group['betas']
 
             current_step = state['step']
-            if group['kourkoutas_beta']:
+            if group.get('kourkoutas_beta', False):
                 # Call prepare_step() once at the beginning of the step for all params
                 self.kourkoutas_helper.maybe_prepare_step(current_step)
                 # Accumulate current grad's norm for the *next* step
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED

@@ -50,6 +50,12 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
             pth entry of each tensor. For values greater than 1 this is an approximation to standard
             Prodigy. Values ~11 are reasonable (default 11).
+        prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+            after the specified optimiser step and release all state memory required by Prodigy
+            (default: 0).
+        d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
+            to prevent sudden, volatile increases in the adaptive step size (`d`).
+            (default: True)
     """
 
     def __init__(
@@ -63,7 +69,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         orthogonal_gradient: bool = False,
         cautious_mask: bool = False,
         clip_threshold: float = 0.0,
-        nnmf_factor: bool =
+        nnmf_factor: bool = False,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,

@@ -72,6 +78,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         safeguard_warmup: bool = False,
         fsdp_in_use: bool = False,
         slice_p: int = 11,
+        prodigy_steps: int = 0,
+        d_limiter: bool = True,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")

@@ -90,6 +98,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
             growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
             fsdp_in_use=fsdp_in_use,
+            prodigy_steps=prodigy_steps,
+            d_limiter=d_limiter,
         )
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
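Taken together, the additions above introduce two new constructor arguments. A hedged usage sketch follows; the model, learning rate, and the value 500 are placeholders rather than recommendations, and only `prodigy_steps` and `d_limiter` are taken from this diff:

import torch
from adv_optm.optim.Lion_Prodigy_adv import Lion_Prodigy_adv  # module path as listed in RECORD

model = torch.nn.Linear(10, 2)  # stand-in model

opt = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,               # placeholder; the constructor only requires lr > 0
    prodigy_steps=500,    # stop Prodigy's d adaptation and free its state after step 500
    d_limiter=True,       # clamp each new d_hat to at most d * 2**0.25 (the default)
)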
@@ -235,20 +245,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             # Update momentum
             exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
 
-            [old lines 238-243 not rendered in this diff view]
+            prodigy_steps = group['prodigy_steps']
+            if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+                # --- Accumulate Prodigy stats ---
+                d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+                s, p0 = state['s'], state['p0']
+                grad_flat = grad.flatten().float()
+                p_flat = p.data.flatten().float()
+                p0 = p0.float()
 
-            [old line 245 not rendered in this diff view]
+                self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
 
-            [old lines 247-249 not rendered in this diff view]
+                alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+                s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+                self.d_denom += s.abs().sum().item()
 
-            [old line 251 not rendered in this diff view]
+                del s, p0, grad_flat, p_flat, alpha
+            else:
+                # Free memory if prodigy_steps is reached
+                if 's' in state:
+                    del state['s']
+                if 'p0' in state:
+                    del state['p0']
 
             if group["weight_decay"] != 0:
                 if p.dtype == torch.bfloat16 and self.stochastic_rounding:
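The new gate above only accumulates Prodigy's numerator and denominator while the group's step counter `k` is below `prodigy_steps` (or when `prodigy_steps <= 0`, i.e. the cutoff is disabled); once the limit is passed it deletes the `s` and `p0` buffers so their memory can be reclaimed. The following is a simplified, self-contained sketch of that control flow, with toy shapes and hyperparameters and no safeguard_warmup handling; it is not the package's actual API:

import torch

def prodigy_accumulate_or_release(state, group, grad, p, d, dlr, stats):
    """Toy version of the 1.1.3 gate: accumulate Prodigy statistics while
    k < prodigy_steps, otherwise drop the Prodigy-only buffers."""
    prodigy_steps = group['prodigy_steps']
    if prodigy_steps <= 0 or group['k'] < prodigy_steps:
        d0, slice_p, beta3 = group['d0'], group['slice_p'], group['beta3']
        s, p0 = state['s'], state['p0']          # p0 stored already sliced in this toy
        grad_flat = grad.flatten().float()
        p_flat = p.flatten().float()
        # Numerator: <g, p0 - p> over every slice_p-th entry, scaled by d/d0 and dlr.
        stats['d_numerator'] += (d / d0) * dlr * torch.dot(
            grad_flat[::slice_p], p0.float() - p_flat[::slice_p]).item()
        # Denominator: decayed accumulator of sliced gradients.
        s.mul_(beta3).add_(grad_flat[::slice_p], alpha=(d / d0) * dlr)
        stats['d_denom'] += s.abs().sum().item()
    else:
        # Past prodigy_steps: release the buffers, mirroring the else-branch above.
        state.pop('s', None)
        state.pop('p0', None)

# Toy usage with made-up shapes.
p = torch.randn(100)
group = {'prodigy_steps': 0, 'k': 7, 'd0': 1e-6, 'slice_p': 11, 'beta3': 0.9}
state = {'s': torch.zeros(10), 'p0': p.flatten()[::11].clone()}
stats = {'d_numerator': 0.0, 'd_denom': 0.0}
prodigy_accumulate_or_release(state, group, torch.randn(100), p, d=1e-6, dlr=1e-6, stats=stats)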
@@ -287,29 +305,37 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
     def calculate_d(self):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]
-        [old lines 290-310 not rendered in this diff view]
+        # Only perform d-adaptation if prodigy_steps has not been reached
+        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+        if prodigy_active:
+            d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+            if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+                # Use the device of the first parameter to avoid hardcoding '.cuda()'
+                device = self.param_groups[0]['params'][0].device
+                dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_d_numerator = dist_tensor[0].item()
+                global_d_denom = dist_tensor[1].item()
+            else:
+                global_d_numerator = self.d_numerator
+                global_d_denom = self.d_denom
+
+            d_hat = self.d
+            if global_d_denom > 0:
+                d_hat = d_coef * global_d_numerator / global_d_denom
+                if g_group.get('d_limiter', False):
+                    d_hat = min(self.d * (2 ** 0.25), d_hat)
+                if self.d == g_group['d0']:
+                    self.d = max(self.d, d_hat)
+                d_max = max(d_max, d_hat)
+                self.d = min(d_max, self.d * growth_rate)
+
+            for group in self.param_groups:
+                group['d_numerator'] = global_d_numerator
+                group['d'] = self.d
+                group['d_max'] = d_max
+        # Increment step counter for all groups, regardless of whether d was updated
         for group in self.param_groups:
-            group['d_numerator'] = global_d_numerator
-            group['d'] = self.d
-            group['d_max'] = d_max
             group['k'] += 1
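When the limiter is on, each new estimate is clamped to `min(d * 2**0.25, d_hat)`, so `d` can grow by at most roughly 19% per `calculate_d` call no matter how large the raw ratio `d_coef * numerator / denominator` is. A small standalone illustration of the update rule shown above, with invented numbers and not the package API:

def next_d(d, d_max, d0, numerator, denominator,
           d_coef=1.0, growth_rate=float('inf'), d_limiter=True):
    """Isolated form of the d update from calculate_d; toy reimplementation."""
    d_hat = d
    if denominator > 0:
        d_hat = d_coef * numerator / denominator
        if d_limiter:
            d_hat = min(d * (2 ** 0.25), d_hat)   # cap growth at ~19% per call
        if d == d0:
            d = max(d, d_hat)
        d_max = max(d_max, d_hat)
        d = min(d_max, d * growth_rate)
    return d, d_max

# The raw estimate wants to jump from 1e-6 to 5e-5; the limiter only allows ~1.19e-6.
print(next_d(d=1e-6, d_max=1e-6, d0=1e-6, numerator=5e-5, denominator=1.0))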
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -304,7 +304,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
 
             current_step = state['step']
-            if group['kourkoutas_beta']:
+            if group.get('kourkoutas_beta', False):
                 # Call prepare_step() once at the beginning of the step for all params
                 self.kourkoutas_helper.maybe_prepare_step(current_step)
                 # Accumulate current grad's norm for the *next* step

@@ -515,7 +515,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             d_hat = self.d
             if global_d_denom > 0:
                 d_hat = d_coef * global_d_numerator / global_d_denom
-                if g_group['d_limiter']:
+                if g_group.get('d_limiter', False):
                     d_hat = min(self.d * (2 ** 0.25), d_hat)
                 if self.d == g_group['d0']:
                     self.d = max(self.d, d_hat)
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED

@@ -191,7 +191,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             beta1_final, beta2 = group["betas"]
 
             current_step = state['step']
-            if group['kourkoutas_beta']:
+            if group.get('kourkoutas_beta', False):
                 # Call prepare_step() once at the beginning of the step for all params
                 self.kourkoutas_helper.maybe_prepare_step(current_step)
                 # Accumulate current grad's norm for the *next* step

@@ -210,7 +210,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 
             if group['use_bias_correction']:
                 state['num_sum'] = beta1 * state['num_sum'] + 1.0
-                if group['kourkoutas_beta']:
+                if group.get('kourkoutas_beta', False):
                     state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
                 else:
                     state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
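When Kourkoutas-beta is active, `den_sum` is accumulated with the fixed `betas[1]` (the ceiling referred to as `beta2_max` elsewhere in this release) instead of the dynamically chosen `beta2`, keeping the bias-correction denominator on a fixed schedule. As a quick sanity check of that recurrence (assuming `den_sum` starts at 0), it reproduces the familiar `1 - b**t` correction term:

# den_sum_t = b * den_sum_{t-1} + (1 - b), starting from 0, equals 1 - b**t.
b = 0.999          # stand-in for betas[1] / beta2_max
den_sum = 0.0
for t in range(1, 6):
    den_sum = b * den_sum + (1.0 - b)
    print(t, den_sum, 1.0 - b ** t)   # the two columns agree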
{adv_optm-1.1.1.dist-info → adv_optm-1.1.3.dist-info}/RECORD
CHANGED

@@ -1,10 +1,10 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdamW_adv.py,sha256=
-adv_optm/optim/Adopt_adv.py,sha256=
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=
+adv_optm/__init__.py,sha256=9UZMsxIFudooscrxW4TwKgj3PkrKdC5ZFEOAkYpkrMw,306
+adv_optm/optim/AdamW_adv.py,sha256=7vWfPS2J54U9ZKFQiNJ_l86PvITb0MQ61Fy4Fzmf1d4,17479
+adv_optm/optim/Adopt_adv.py,sha256=NXbtPrGm3tZr06cApi5oEHZ2F1zwss3tRi15SGnrYPc,21426
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Prodigy_adv.py,sha256=
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=
+adv_optm/optim/Prodigy_adv.py,sha256=0_XG5YnMQTv-zJysJHlJniSo5kGYdX3p3o1e33HLt78,25897
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318

@@ -13,8 +13,8 @@ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.1.
-adv_optm-1.1.
-adv_optm-1.1.
-adv_optm-1.1.
-adv_optm-1.1.
+adv_optm-1.1.3.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.1.3.dist-info/METADATA,sha256=IGemhIn9C4Zg9nE5VaiZjVuRqnBGNxlLNaXabRVXG8Y,14019
+adv_optm-1.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.1.3.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.1.3.dist-info/RECORD,,

File without changes
File without changes
File without changes