heavyball 1.2.3__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/chainable.py +8 -7
- heavyball/utils.py +6 -7
- {heavyball-1.2.3.dist-info → heavyball-1.3.0.dist-info}/METADATA +1 -1
- heavyball-1.3.0.dist-info/RECORD +8 -0
- heavyball-1.2.3.dist-info/RECORD +0 -8
- {heavyball-1.2.3.dist-info → heavyball-1.3.0.dist-info}/LICENSE +0 -0
- {heavyball-1.2.3.dist-info → heavyball-1.3.0.dist-info}/WHEEL +0 -0
- {heavyball-1.2.3.dist-info → heavyball-1.3.0.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -438,14 +438,15 @@ class ChainOpt(utils.StatefulOptimizer):
|
|
438
438
|
|
439
439
|
for param in p:
|
440
440
|
state = self.state_(param)
|
441
|
-
if 'step'
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
441
|
+
if 'step' in state:
|
442
|
+
step = state['step']
|
443
|
+
elif self.compile_step:
|
444
|
+
step = utils.scalar_guard(0, param)
|
445
|
+
else:
|
446
|
+
step = 0
|
446
447
|
break
|
447
448
|
|
448
|
-
group['step'] = step
|
449
|
+
group['step'] = state['step'] = step = step + 1
|
449
450
|
|
450
451
|
if group['warmup_steps'] and step < group['warmup_steps']:
|
451
452
|
group['lr'] = group['base_lr'] * step / group['warmup_steps']
|
@@ -494,7 +495,7 @@ class BaseOpt(ChainOpt):
|
|
494
495
|
auto_fuse: bool = True
|
495
496
|
|
496
497
|
def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
|
497
|
-
palm: bool = use_default, compile_step: bool = use_default
|
498
|
+
palm: bool = use_default, *fns, compile_step: bool = use_default):
|
498
499
|
if default(update_clipping, self.update_clipping) is None:
|
499
500
|
if fns and self.auto_fuse:
|
500
501
|
args, kwargs = None, None
|
heavyball/utils.py
CHANGED
@@ -88,7 +88,7 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
|
|
88
88
|
except ZeroDivisionError:
|
89
89
|
ckp1 = 0
|
90
90
|
|
91
|
-
update, parameters, z = list_guard(update, parameters, z)
|
91
|
+
update, parameters, z, grad = list_guard(update, parameters, z, grad)
|
92
92
|
lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
|
93
93
|
_compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
|
94
94
|
return weight_sum
|
@@ -912,13 +912,12 @@ def copy_stochastic_(target: Tensor, source: Tensor):
|
|
912
912
|
@decorator_knowngood
|
913
913
|
def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
|
914
914
|
g: List[Optional[Tensor]]):
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
for p32_, u32_, g_, p_ in zip(p32, u32, g, p): # lr is data-dependent -> can't compile a foreach
|
915
|
+
for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
|
916
|
+
u_ = promote(u_.view_as(p_))
|
917
|
+
p32_ = promote(p_)
|
919
918
|
if caution:
|
920
|
-
|
921
|
-
p32_ = p32_ * (1 - decay * lr) +
|
919
|
+
u_ = _compilable_cautioning(promote(g_), u_)
|
920
|
+
p32_ = p32_ * (1 - decay * lr) + u_ * -lr
|
922
921
|
copy_stochastic_(p_, p32_)
|
923
922
|
|
924
923
|
|
@@ -0,0 +1,8 @@
|
|
1
|
+
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
+
heavyball/chainable.py,sha256=U-yxSKYIcnMJFvqQlVBKL9BSpFvUAVPg4FxgI5sN21g,21119
|
3
|
+
heavyball/utils.py,sha256=0bqa2J3oIp3qvxpUsvQgSh5RBkO9fR9WX8nFFWvcLG0,47901
|
4
|
+
heavyball-1.3.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
+
heavyball-1.3.0.dist-info/METADATA,sha256=1zOefT9mm8_SnOReLEAHHJkY10CC9aWquTCanWHS4ww,12022
|
6
|
+
heavyball-1.3.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
+
heavyball-1.3.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
+
heavyball-1.3.0.dist-info/RECORD,,
|
heavyball-1.2.3.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
-
heavyball/chainable.py,sha256=u9w2z_aSslcokWVCiiXQJ8GSPlOhgrOFUYAwt2JfTzI,21100
|
3
|
-
heavyball/utils.py,sha256=I2zfiB_-EP35LYr-vLyxPNl8_uJo2se3Id0IWjZeVjg,47951
|
4
|
-
heavyball-1.2.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.2.3.dist-info/METADATA,sha256=EtW_3QIUKrKpyYUfXmGQm3_EpZkr8oQyow7gAyC4Ges,12022
|
6
|
-
heavyball-1.2.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.2.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.2.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|