heavyball-1.2.2-py3-none-any.whl → heavyball-1.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/chainable.py +2 -2
- heavyball/utils.py +30 -12
- {heavyball-1.2.2.dist-info → heavyball-1.2.3.dist-info}/METADATA +1 -1
- heavyball-1.2.3.dist-info/RECORD +8 -0
- heavyball-1.2.2.dist-info/RECORD +0 -8
- {heavyball-1.2.2.dist-info → heavyball-1.2.3.dist-info}/LICENSE +0 -0
- {heavyball-1.2.2.dist-info → heavyball-1.2.3.dist-info}/WHEEL +0 -0
- {heavyball-1.2.2.dist-info → heavyball-1.2.3.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -183,8 +183,8 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
|
|
183
183
|
@no_state
|
184
184
|
def update_by_schedule_free(group, update, grad, param, z):
|
185
185
|
group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
|
186
|
-
utils.get_beta1(group), param, z, update,
|
187
|
-
group['weight_decay'])
|
186
|
+
utils.get_beta1(group), param, z, update, grad, group['caution'],
|
187
|
+
group['r'], group['step'], group['weight_decay'])
|
188
188
|
raise SkipUpdate
|
189
189
|
|
190
190
|
|
heavyball/utils.py
CHANGED
@@ -61,22 +61,25 @@ def warmup(lr: float, step: int, warmup_steps: int):
|
|
61
61
|
|
62
62
|
|
63
63
|
@decorator_knowngood
|
64
|
-
def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor,
|
65
|
-
beta1: Tensor, decay: float):
|
66
|
-
for op, oz, g_ in zip(p, z, grad):
|
67
|
-
|
68
|
-
p_, z_,
|
64
|
+
def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
|
65
|
+
beta1: Tensor, decay: float, grad: List[Tensor], caution):
|
66
|
+
for op, oz, u_, g_ in zip(p, z, update, grad):
|
67
|
+
u_ = u_.view_as(op)
|
68
|
+
p_, z_, u_ = map(promote, (op, oz, u_))
|
69
69
|
if decay != 0:
|
70
|
-
|
70
|
+
u_ = u_ + p_ * decay
|
71
|
+
if caution:
|
72
|
+
u_ = _compilable_cautioning(u_, g_)
|
71
73
|
p_ = p_.lerp(z_, ckp1)
|
72
|
-
p_ = p_ +
|
73
|
-
z_ = z_ +
|
74
|
+
p_ = p_ + u_ * (lr * (beta1 * (1 - ckp1)) - lr)
|
75
|
+
z_ = z_ + u_ * -lr
|
74
76
|
copy_stochastic_(op, p_)
|
75
77
|
copy_stochastic_(oz, z_)
|
76
78
|
|
77
79
|
|
78
80
|
def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
|
79
|
-
z: List[Tensor],
|
81
|
+
z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0,
|
82
|
+
step: int = 0, decay: float = 0.0):
|
80
83
|
weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
|
81
84
|
weight_sum = weight_sum + weight
|
82
85
|
|
@@ -85,9 +88,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
|
|
85
88
|
except ZeroDivisionError:
|
86
89
|
ckp1 = 0
|
87
90
|
|
88
|
-
|
91
|
+
update, parameters, z = list_guard(update, parameters, z)
|
89
92
|
lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
|
90
|
-
_compilable_schedule_free_(parameters, z, ckp1,
|
93
|
+
_compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
|
91
94
|
return weight_sum
|
92
95
|
|
93
96
|
|
@@ -1220,13 +1223,16 @@ def update_triu_(q_state, materialised):
|
|
1220
1223
|
assert shape0 == shape1
|
1221
1224
|
copy_stochastic_(q, m)
|
1222
1225
|
|
1226
|
+
|
1223
1227
|
_warned = set()
|
1224
1228
|
|
1229
|
+
|
1225
1230
|
def warn_once(msg):
|
1226
1231
|
if msg not in _warned:
|
1227
1232
|
warnings.warn(msg)
|
1228
1233
|
_warned.add(msg)
|
1229
1234
|
|
1235
|
+
|
1230
1236
|
def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
|
1231
1237
|
name: str = 'cumulative_prob'):
|
1232
1238
|
group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
|
@@ -1369,4 +1375,16 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
|
|
1369
1375
|
seen_params.clear()
|
1370
1376
|
|
1371
1377
|
for p in parameters:
|
1372
|
-
p.register_post_accumulate_grad_hook(_step)
|
1378
|
+
p.register_post_accumulate_grad_hook(_step)
|
1379
|
+
|
1380
|
+
|
1381
|
+
@decorator_knowngood
|
1382
|
+
def _compilable_caution_no_scale(g: Tensor, update: Tensor):
|
1383
|
+
mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
|
1384
|
+
update = update.masked_fill(mask, 0)
|
1385
|
+
return update
|
1386
|
+
|
1387
|
+
|
1388
|
+
def disable_caution_scaling():
|
1389
|
+
global _compilable_cautioning
|
1390
|
+
_compilable_cautioning = _compilable_caution_no_scale
|
@@ -0,0 +1,8 @@
|
|
1
|
+
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
+
heavyball/chainable.py,sha256=u9w2z_aSslcokWVCiiXQJ8GSPlOhgrOFUYAwt2JfTzI,21100
|
3
|
+
heavyball/utils.py,sha256=I2zfiB_-EP35LYr-vLyxPNl8_uJo2se3Id0IWjZeVjg,47951
|
4
|
+
heavyball-1.2.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
+
heavyball-1.2.3.dist-info/METADATA,sha256=EtW_3QIUKrKpyYUfXmGQm3_EpZkr8oQyow7gAyC4Ges,12022
|
6
|
+
heavyball-1.2.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
+
heavyball-1.2.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
+
heavyball-1.2.3.dist-info/RECORD,,
|
heavyball-1.2.2.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
-
heavyball/chainable.py,sha256=O3vBuwDKSkcl6yGrcEgSylqZ6htjRTg1NIA1sNY4KcA,21076
|
3
|
-
heavyball/utils.py,sha256=n80bsTZ2NUD4L3YERaC7ydKaQzW4kSDNoBTLF0DfE1g,47393
|
4
|
-
heavyball-1.2.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.2.2.dist-info/METADATA,sha256=k_qyB9aR8PpreGu_XxNv9z5pyTnH1TMcA8Ah3k1hUgQ,12022
|
6
|
-
heavyball-1.2.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.2.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.2.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|