heavyball 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/chainable.py +10 -9
- heavyball/utils.py +35 -18
- {heavyball-1.2.2.dist-info → heavyball-1.3.0.dist-info}/METADATA +1 -1
- heavyball-1.3.0.dist-info/RECORD +8 -0
- heavyball-1.2.2.dist-info/RECORD +0 -8
- {heavyball-1.2.2.dist-info → heavyball-1.3.0.dist-info}/LICENSE +0 -0
- {heavyball-1.2.2.dist-info → heavyball-1.3.0.dist-info}/WHEEL +0 -0
- {heavyball-1.2.2.dist-info → heavyball-1.3.0.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -183,8 +183,8 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
|
|
183
183
|
@no_state
|
184
184
|
def update_by_schedule_free(group, update, grad, param, z):
|
185
185
|
group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
|
186
|
-
utils.get_beta1(group), param, z, update,
|
187
|
-
group['weight_decay'])
|
186
|
+
utils.get_beta1(group), param, z, update, grad, group['caution'],
|
187
|
+
group['r'], group['step'], group['weight_decay'])
|
188
188
|
raise SkipUpdate
|
189
189
|
|
190
190
|
|
@@ -438,14 +438,15 @@ class ChainOpt(utils.StatefulOptimizer):
|
|
438
438
|
|
439
439
|
for param in p:
|
440
440
|
state = self.state_(param)
|
441
|
-
if 'step'
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
441
|
+
if 'step' in state:
|
442
|
+
step = state['step']
|
443
|
+
elif self.compile_step:
|
444
|
+
step = utils.scalar_guard(0, param)
|
445
|
+
else:
|
446
|
+
step = 0
|
446
447
|
break
|
447
448
|
|
448
|
-
group['step'] = step
|
449
|
+
group['step'] = state['step'] = step = step + 1
|
449
450
|
|
450
451
|
if group['warmup_steps'] and step < group['warmup_steps']:
|
451
452
|
group['lr'] = group['base_lr'] * step / group['warmup_steps']
|
@@ -494,7 +495,7 @@ class BaseOpt(ChainOpt):
|
|
494
495
|
auto_fuse: bool = True
|
495
496
|
|
496
497
|
def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
|
497
|
-
palm: bool = use_default, compile_step: bool = use_default
|
498
|
+
palm: bool = use_default, *fns, compile_step: bool = use_default):
|
498
499
|
if default(update_clipping, self.update_clipping) is None:
|
499
500
|
if fns and self.auto_fuse:
|
500
501
|
args, kwargs = None, None
|
heavyball/utils.py
CHANGED
@@ -61,22 +61,25 @@ def warmup(lr: float, step: int, warmup_steps: int):
|
|
61
61
|
|
62
62
|
|
63
63
|
@decorator_knowngood
|
64
|
-
def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor,
|
65
|
-
beta1: Tensor, decay: float):
|
66
|
-
for op, oz, g_ in zip(p, z, grad):
|
67
|
-
|
68
|
-
p_, z_,
|
64
|
+
def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
|
65
|
+
beta1: Tensor, decay: float, grad: List[Tensor], caution):
|
66
|
+
for op, oz, u_, g_ in zip(p, z, update, grad):
|
67
|
+
u_ = u_.view_as(op)
|
68
|
+
p_, z_, u_ = map(promote, (op, oz, u_))
|
69
69
|
if decay != 0:
|
70
|
-
|
70
|
+
u_ = u_ + p_ * decay
|
71
|
+
if caution:
|
72
|
+
u_ = _compilable_cautioning(u_, g_)
|
71
73
|
p_ = p_.lerp(z_, ckp1)
|
72
|
-
p_ = p_ +
|
73
|
-
z_ = z_ +
|
74
|
+
p_ = p_ + u_ * (lr * (beta1 * (1 - ckp1)) - lr)
|
75
|
+
z_ = z_ + u_ * -lr
|
74
76
|
copy_stochastic_(op, p_)
|
75
77
|
copy_stochastic_(oz, z_)
|
76
78
|
|
77
79
|
|
78
80
|
def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
|
79
|
-
z: List[Tensor],
|
81
|
+
z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0,
|
82
|
+
step: int = 0, decay: float = 0.0):
|
80
83
|
weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
|
81
84
|
weight_sum = weight_sum + weight
|
82
85
|
|
@@ -85,9 +88,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
|
|
85
88
|
except ZeroDivisionError:
|
86
89
|
ckp1 = 0
|
87
90
|
|
88
|
-
|
91
|
+
update, parameters, z, grad = list_guard(update, parameters, z, grad)
|
89
92
|
lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
|
90
|
-
_compilable_schedule_free_(parameters, z, ckp1,
|
93
|
+
_compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
|
91
94
|
return weight_sum
|
92
95
|
|
93
96
|
|
@@ -909,13 +912,12 @@ def copy_stochastic_(target: Tensor, source: Tensor):
|
|
909
912
|
@decorator_knowngood
|
910
913
|
def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
|
911
914
|
g: List[Optional[Tensor]]):
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
for p32_, u32_, g_, p_ in zip(p32, u32, g, p): # lr is data-dependent -> can't compile a foreach
|
915
|
+
for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
|
916
|
+
u_ = promote(u_.view_as(p_))
|
917
|
+
p32_ = promote(p_)
|
916
918
|
if caution:
|
917
|
-
|
918
|
-
p32_ = p32_ * (1 - decay * lr) +
|
919
|
+
u_ = _compilable_cautioning(promote(g_), u_)
|
920
|
+
p32_ = p32_ * (1 - decay * lr) + u_ * -lr
|
919
921
|
copy_stochastic_(p_, p32_)
|
920
922
|
|
921
923
|
|
@@ -1220,13 +1222,16 @@ def update_triu_(q_state, materialised):
|
|
1220
1222
|
assert shape0 == shape1
|
1221
1223
|
copy_stochastic_(q, m)
|
1222
1224
|
|
1225
|
+
|
1223
1226
|
_warned = set()
|
1224
1227
|
|
1228
|
+
|
1225
1229
|
def warn_once(msg):
|
1226
1230
|
if msg not in _warned:
|
1227
1231
|
warnings.warn(msg)
|
1228
1232
|
_warned.add(msg)
|
1229
1233
|
|
1234
|
+
|
1230
1235
|
def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
|
1231
1236
|
name: str = 'cumulative_prob'):
|
1232
1237
|
group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
|
@@ -1369,4 +1374,16 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
|
|
1369
1374
|
seen_params.clear()
|
1370
1375
|
|
1371
1376
|
for p in parameters:
|
1372
|
-
p.register_post_accumulate_grad_hook(_step)
|
1377
|
+
p.register_post_accumulate_grad_hook(_step)
|
1378
|
+
|
1379
|
+
|
1380
|
+
@decorator_knowngood
|
1381
|
+
def _compilable_caution_no_scale(g: Tensor, update: Tensor):
|
1382
|
+
mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
|
1383
|
+
update = update.masked_fill(mask, 0)
|
1384
|
+
return update
|
1385
|
+
|
1386
|
+
|
1387
|
+
def disable_caution_scaling():
|
1388
|
+
global _compilable_cautioning
|
1389
|
+
_compilable_cautioning = _compilable_caution_no_scale
|
@@ -0,0 +1,8 @@
|
|
1
|
+
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
+
heavyball/chainable.py,sha256=U-yxSKYIcnMJFvqQlVBKL9BSpFvUAVPg4FxgI5sN21g,21119
|
3
|
+
heavyball/utils.py,sha256=0bqa2J3oIp3qvxpUsvQgSh5RBkO9fR9WX8nFFWvcLG0,47901
|
4
|
+
heavyball-1.3.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
+
heavyball-1.3.0.dist-info/METADATA,sha256=1zOefT9mm8_SnOReLEAHHJkY10CC9aWquTCanWHS4ww,12022
|
6
|
+
heavyball-1.3.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
+
heavyball-1.3.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
+
heavyball-1.3.0.dist-info/RECORD,,
|
heavyball-1.2.2.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
-
heavyball/chainable.py,sha256=O3vBuwDKSkcl6yGrcEgSylqZ6htjRTg1NIA1sNY4KcA,21076
|
3
|
-
heavyball/utils.py,sha256=n80bsTZ2NUD4L3YERaC7ydKaQzW4kSDNoBTLF0DfE1g,47393
|
4
|
-
heavyball-1.2.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.2.2.dist-info/METADATA,sha256=k_qyB9aR8PpreGu_XxNv9z5pyTnH1TMcA8Ah3k1hUgQ,12022
|
6
|
-
heavyball-1.2.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.2.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.2.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|