torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/momentum/matrix_momentum.py

```diff
@@ -0,0 +1,124 @@
+from typing import Literal
+
+import torch
+
+from ...core import Module, apply
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+class MatrixMomentum(Module):
+    """
+    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
+    Evaluates hessian vector product on each step (via finite difference or autograd).
+
+    `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
+
+    Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
+    """
+    def __init__(self, mu=0.1, beta:float=1, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
+        defaults = dict(mu=mu, beta=beta, hvp_mode=hvp_mode, h=h)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    @torch.no_grad
+    def step(self, vars):
+        assert vars.closure is not None
+        prev_update = self.get_state('prev_update', params=vars.params, cls=TensorList)
+        hvp_mode = self.settings[vars.params[0]]['hvp_mode']
+        h = self.settings[vars.params[0]]['h']
+
+        mu,beta = self.get_settings('mu','beta', params=vars.params, cls=NumberList)
+
+        if hvp_mode == 'autograd':
+            with torch.enable_grad():
+                grad = vars.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
+
+        elif hvp_mode == 'forward':
+            vars.get_grad()
+            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        elif hvp_mode == 'central':
+            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        else:
+            raise ValueError(hvp_mode)
+
+        if 'hvp_tfm' in self.children:
+            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
+
+        update = TensorList(vars.get_update())
+
+        hvp_ = as_tensorlist(hvp_)
+        update.add_(prev_update - hvp_*mu)
+        prev_update.set_(update * beta)
+        vars.update = update
+        return vars
+
+
+class AdaptiveMatrixMomentum(Module):
+    """
+    Mu here is estimated as ||s_k||/||y_k||.
+    """
+    def __init__(self, mu_mul:float=1, beta:float=1, eps=1e-4, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
+        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_mode=hvp_mode, h=h, eps=eps)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    @torch.no_grad
+    def step(self, vars):
+        assert vars.closure is not None
+        prev_update, prev_params, prev_grad = self.get_state('prev_update', 'prev_params', 'prev_grad', params=vars.params, cls=TensorList)
+
+        settings = self.settings[vars.params[0]]
+        hvp_mode = settings['hvp_mode']
+        h = settings['h']
+        eps = settings['eps']
+
+        mu_mul, beta = self.get_settings('mu_mul','beta', params=vars.params, cls=NumberList)
+
+        if hvp_mode == 'autograd':
+            with torch.enable_grad():
+                grad = vars.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
+
+        elif hvp_mode == 'forward':
+            vars.get_grad()
+            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        elif hvp_mode == 'central':
+            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        else:
+            raise ValueError(hvp_mode)
+
+        if 'hvp_tfm' in self.children:
+            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
+
+        # adaptive part
+        update = TensorList(vars.get_update())
+
+        s_k = vars.params - prev_params
+        prev_params.copy_(vars.params)
+
+        assert vars.grad is not None
+        y_k = vars.grad - prev_grad
+        prev_grad.copy_(vars.grad)
+
+        ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
+
+        # matrix momentum uppdate
+        hvp_ = as_tensorlist(hvp_)
+        update.add_(prev_update - hvp_*ada_mu)
+        prev_update.set_(update * beta)
+        vars.update = update
+        return vars
+
```
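For intuition, the update rule `MatrixMomentum.step` implements above (`update += prev_update - mu * (H @ prev_update)`, then `prev_update = beta * update`) can be sketched in plain PyTorch on a toy quadratic. The quadratic, learning rate and loop below are assumptions of this sketch, not part of torchzero:

```python
import torch

# Toy ill-conditioned quadratic f(x) = 0.5 * x^T A x with eigenvalues 100 and 1.
A = torch.diag(torch.tensor([100.0, 1.0]))
x = torch.tensor([1.0, 1.0], requires_grad=True)

lr, mu, beta = 1e-3, 5e-3, 1.0      # mu kept below 1 / (largest eigenvalue) = 0.01
prev_update = torch.zeros_like(x)   # the module's 'prev_update' state

for step in range(200):
    loss = 0.5 * x @ A @ x
    (grad,) = torch.autograd.grad(loss, x, create_graph=True)
    # Hessian-vector product with the previous update (the 'autograd' hvp_mode branch)
    (Hp,) = torch.autograd.grad(grad, x, grad_outputs=prev_update)
    with torch.no_grad():
        update = grad + prev_update - mu * Hp   # update.add_(prev_update - hvp_*mu)
        prev_update = beta * update             # prev_update.set_(update * beta)
        x -= lr * update                        # a downstream LR module would normally apply this

print(float(0.5 * x @ A @ x))  # should be far below the initial loss of 50.5
```

The curvature term shrinks the momentum carried along high-curvature directions (eigenvalue 100 here) while leaving the flat direction's momentum almost untouched, which is the behaviour the Orr & Leen reference describes.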
torchzero/modules/momentum/momentum.py

```diff
@@ -1,106 +1,43 @@
-from
-
-import torch
-
-from ...
-from ...
... 37 removed lines (7-43) not shown ...
-    except this also supports dampening.
-
-    Args:
-        decay (float, optional): momentum decay. Defaults to 0.9.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-    """
-    def __init__(self, decay: float = 0.9, dampening: float = 0, ):
-        defaults = dict(momentum = decay, dampening = dampening)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        velocity = self.get_state_key('velocity')
-        settings = self.get_all_group_keys()
-        _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
-        return ascent
-
-class GradientAveraging(OptimizerModule):
-    """Averages last 2 gradients (TODO)"""
-    def __init__(self, dampening: float = 0, ):
-        defaults = dict(dampening = dampening)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        velocity = self.get_state_key('velocity')
-        dampening = self.get_group_key('dampening')
-
-        new_direction = ascent + velocity * (1-dampening)
-        velocity.copy_(ascent)
-
-        return new_direction
-
-
-class RandomCoordinateMomentum(OptimizerModule):
-    """Only uses `p` random coordinates of the new update. Other coordinates remain from previous update.
-    This works but I don't know if it is any good.
-
-    Args:
-        p (float, optional): probability to update velocity with a new weigh value. Defaults to 0.1.
-        nesterov (bool, optional): if False, update uses delayed momentum. Defaults to True.
-    """
-    def __init__(self, p: float = 0.1, nesterov=True):
-        defaults = dict(p=p)
-        super().__init__(defaults)
-        self.nesterov = nesterov
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        velocity = self.get_state_key('velocity', init = ascent)
-        settings = self.get_all_group_keys()
-
-        # pick p veclocity indexes to update with the new ascent direction
-        indexes = ascent.bernoulli_like(settings['p']).as_bool()
-
-        if self.nesterov:
-            # update the velocity at those indexes
-            velocity.masked_set_(mask = indexes, value = ascent)
-            return velocity.clone()
-
-        new_ascent = velocity.clone()
-        velocity.masked_set_(mask = indexes, value = ascent)
-        return new_ascent
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import NumberList, TensorList
+from .ema import EMA
+
+
+class HeavyBall(EMA):
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
+        super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
+
+def nag_(
+    tensors_: TensorList,
+    velocity_: TensorList,
+    momentum: float | NumberList,
+    dampening: float | NumberList,
+    lerp: bool = False,
+):
+    """Nesterov momentum.
+
+    Returns `tensors_`"""
+    if lerp: velocity_.lerp_(tensors_, 1 - momentum)
+    else: velocity_.add_(tensors_).mul_(momentum)
+
+    tensors_ += velocity_.lazy_mul(1 - dampening)
+
+    return tensors_
+
+
+class NAG(Transform):
+    def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
+        defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        velocity = self.get_state('velocity', params=params, cls=TensorList)
+        lerp = self.settings[params[0]]['lerp']
+
+        momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+        return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
```
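Restated with plain tensors (the TensorList/NumberList and Transform plumbing stripped out, so this is an illustrative sketch rather than library code), the `nag_` helper above behaves as follows:

```python
import torch

def nag_step(update: torch.Tensor, velocity: torch.Tensor,
             momentum: float = 0.9, dampening: float = 0.0, lerp: bool = False) -> torch.Tensor:
    # mirrors nag_: velocity <- (velocity + update) * momentum (or a lerp when lerp=True),
    # then the returned update is update + velocity * (1 - dampening)
    if lerp:
        velocity.lerp_(update, 1 - momentum)
    else:
        velocity.add_(update).mul_(momentum)
    return update + velocity * (1 - dampening)

v = torch.zeros(3)
g = torch.ones(3)                 # pretend the raw update (e.g. the gradient) is constant
for i in range(3):
    print(nag_step(g, v))         # per coordinate: 1.9, 2.71, 3.439 with momentum=0.9
```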
torchzero/modules/ops/__init__.py

```diff
@@ -0,0 +1,103 @@
+from .accumulate import (
+    AccumulateMaximum,
+    AccumulateMean,
+    AccumulateMinimum,
+    AccumulateProduct,
+    AccumulateSum,
+)
+from .binary import (
+    Add,
+    BinaryOperation,
+    Clip,
+    CopyMagnitude,
+    CopySign,
+    Div,
+    Graft,
+    GraftToUpdate,
+    GramSchimdt,
+    Maximum,
+    Minimum,
+    Mul,
+    Pow,
+    RCopySign,
+    RDiv,
+    RGraft,
+    RPow,
+    RSub,
+    Sub,
+    Threshold,
+)
+from .debug import PrintShape, PrintUpdate
+from .misc import (
+    DivByLoss,
+    Dropout,
+    FillLoss,
+    GradientAccumulation,
+    GradSign,
+    GraftGradToUpdate,
+    GraftToGrad,
+    GraftToParams,
+    LastAbsoluteRatio,
+    LastDifference,
+    LastGradDifference,
+    LastProduct,
+    LastRatio,
+    MulByLoss,
+    Multistep,
+    NegateOnLossIncrease,
+    NoiseSign,
+    Previous,
+    Relative,
+    Sequential,
+    UpdateSign,
+    WeightDropout,
+)
+from .multi import (
+    ClipModules,
+    DivModules,
+    GraftModules,
+    LerpModules,
+    MultiOperation,
+    PowModules,
+    SubModules,
+)
+from .reduce import (
+    MaximumModules,
+    Mean,
+    MinimumModules,
+    Prod,
+    ReduceOperation,
+    Sum,
+    WeightedMean,
+    WeightedSum,
+)
+from .split import Split
+from .switch import Alternate, Switch
+from .unary import (
+    Abs,
+    CustomUnaryOperation,
+    Exp,
+    NanToNum,
+    Negate,
+    Reciprocal,
+    Sign,
+    Sqrt,
+    UnaryLambda,
+    UnaryParameterwiseLambda,
+)
+from .utility import (
+    Clone,
+    Fill,
+    Grad,
+    GradToNone,
+    Identity,
+    NoOp,
+    Ones,
+    Params,
+    Randn,
+    RandomSample,
+    Uniform,
+    Update,
+    UpdateToNone,
+    Zeros,
+)
```
torchzero/modules/ops/accumulate.py

```diff
@@ -0,0 +1,65 @@
+from collections import deque
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import TensorList, NumberList
+
+class AccumulateSum(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        sum = self.get_state('sum', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return sum.add_(tensors).lazy_mul(1-decay, clone=True)
+
+class AccumulateMean(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        mean = self.get_state('mean', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return mean.add_(tensors).lazy_mul(1-decay, clone=True).div_(step)
+
+class AccumulateProduct(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        prod = self.get_state('prod', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return prod.mul_(tensors).lazy_mul(1-decay, clone=True)
+
+class AccumulateMaximum(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        maximum = self.get_state('maximum', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return maximum.maximum_(tensors).lazy_mul(1-decay, clone=True)
+
+class AccumulateMinimum(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        minimum = self.get_state('minimum', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return minimum.minimum_(tensors).lazy_mul(1-decay, clone=True)
+
```
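To make the accumulator semantics concrete, this plain-tensor trace (an illustration, not the torchzero API) shows what `AccumulateMean.transform` above returns over three steps with `decay=0`:

```python
import torch

running_sum = torch.zeros(2)       # the 'mean' state kept per parameter
decay = 0.0
updates = [torch.tensor([1., 2.]), torch.tensor([3., 4.]), torch.tensor([5., 6.])]

for step, t in enumerate(updates, start=1):
    running_sum += t                            # mean.add_(tensors)
    out = running_sum * (1 - decay) / step      # .lazy_mul(1-decay, clone=True).div_(step)
    print(step, out)   # [1, 2] -> [2, 3] -> [3, 4]: the running mean of the updates seen so far
```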
torchzero/modules/ops/binary.py

```diff
@@ -0,0 +1,240 @@
+#pyright: reportIncompatibleMethodOverride=false
+""""""
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
+from operator import itemgetter
+from typing import Any
+
+import torch
+
+from ...core import Chainable, Module, Target, Vars, maybe_chain
+from ...utils import TensorList, tensorlist
+
+
+class BinaryOperation(Module, ABC):
+    """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
+    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
+        super().__init__(defaults=defaults)
+
+        self.operands = {}
+        for k,v in operands.items():
+
+            if isinstance(v, (Module, Sequence)):
+                self.set_child(k, v)
+                self.operands[k] = self.children[k]
+            else:
+                self.operands[k] = v
+
+    @abstractmethod
+    def transform(self, vars: Vars, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
+        """applies the operation to operands"""
+        raise NotImplementedError
+
+    @torch.no_grad
+    def step(self, vars: Vars) -> Vars:
+        # pass cloned update to all module operands
+        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
+
+        for k,v in self.operands.items():
+            if k in self.children:
+                v: Module
+                updated_vars = v.step(vars.clone(clone_update=True))
+                processed_operands[k] = updated_vars.get_update()
+                vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+
+        transformed = self.transform(vars, update=vars.get_update(), **processed_operands)
+        vars.update = list(transformed)
+        return vars
+
+
+class Add(BinaryOperation):
+    def __init__(self, other: Chainable | float, alpha: float = 1):
+        defaults = dict(alpha=alpha)
+        super().__init__(defaults, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[vars.params[0]]['alpha'])
+        else: torch._foreach_add_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+        return update
+
+class Sub(BinaryOperation):
+    def __init__(self, other: Chainable | float, alpha: float = 1):
+        defaults = dict(alpha=alpha)
+        super().__init__(defaults, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[vars.params[0]]['alpha'])
+        else: torch._foreach_sub_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+        return update
+
+class RSub(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        return other - TensorList(update)
+
+class Mul(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        torch._foreach_mul_(update, other)
+        return update
+
+class Div(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        torch._foreach_div_(update, other)
+        return update
+
+class RDiv(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        return other / TensorList(update)
+
+class Pow(BinaryOperation):
+    def __init__(self, exponent: Chainable | float):
+        super().__init__({}, exponent=exponent)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
+        torch._foreach_pow_(update, exponent)
+        return update
+
+class RPow(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
+        torch._foreach_pow_(other, update)
+        return other
+
+class Lerp(BinaryOperation):
+    def __init__(self, end: Chainable, weight: float):
+        defaults = dict(weight=weight)
+        super().__init__(defaults, end=end)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], end: list[torch.Tensor]):
+        torch._foreach_lerp_(update, end, weight=self.get_settings('weight',params=vars))
+        return update
+
+class CopySign(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        return [u.copysign_(o) for u, o in zip(update, other)]
+
+class RCopySign(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        return [o.copysign_(u) for u, o in zip(update, other)]
+CopyMagnitude = RCopySign
+
+class Clip(BinaryOperation):
+    def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
+        super().__init__({}, min=min, max=max)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
+        return TensorList(update).clamp_(min=min, max=max)
+
+class MirroredClip(BinaryOperation):
+    """clip by -value, value"""
+    def __init__(self, value: float | Chainable):
+        super().__init__({}, value=value)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], value: float | list[torch.Tensor]):
+        min = -value if isinstance(value, (int,float)) else [-v for v in value]
+        return TensorList(update).clamp_(min=min, max=value)
+
+class Graft(BinaryOperation):
+    """use direction from update and magnitude from `magnitude` module"""
+    def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, magnitude=magnitude)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+        return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class RGraft(BinaryOperation):
+    """use direction from `direction` module and magnitude from update"""
+
+    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, direction=direction)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], direction: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
+
+GraftToUpdate = RGraft
+
+class Maximum(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        torch._foreach_maximum_(update, other)
+        return update
+
+class Minimum(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        torch._foreach_minimum_(update, other)
+        return update
+
+
+class GramSchimdt(BinaryOperation):
+    """makes update orthonormal to `other`"""
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        update = TensorList(update); other = TensorList(other)
+        return update - (other*update) / ((other*other) + 1e-8)
+
+
+class Threshold(BinaryOperation):
+    """update above/below threshold, value at and below"""
+    def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
+        defaults = dict(update_above=update_above)
+        super().__init__(defaults, threshold=threshold, value=value)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
+        update_above = self.settings[vars.params[0]]['update_above']
+        update = TensorList(update)
+        if update_above:
+            if isinstance(value, list): return update.where_(update>threshold, value)
+            return update.masked_fill_(update<=threshold, value)
+
+        if isinstance(value, list): return update.where_(update<threshold, value)
+        return update.masked_fill_(update>=threshold, value)
```