torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/momentum/matrix_momentum.py CHANGED

```diff
@@ -2,7 +2,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module,
+from ...core import Module, apply_transform, Chainable
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 
@@ -13,105 +13,147 @@ class MatrixMomentum(Module):
 
     `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
 
-
+    Args:
+        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
+        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            How to calculate hessian-vector products.
+            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-
-
+
+    def __init__(
+        self,
+        mu=0.1,
+        beta: float = 1,
+        hvp_method: Literal["autograd", "forward", "central"] = "forward",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
     @torch.no_grad
-    def step(self,
-        assert
-        prev_update = self.get_state('prev_update',
-
-        h = self.settings[
+    def step(self, var):
+        assert var.closure is not None
+        prev_update = self.get_state(var.params, 'prev_update', cls=TensorList)
+        hvp_method = self.settings[var.params[0]]['hvp_method']
+        h = self.settings[var.params[0]]['h']
 
-        mu,beta = self.get_settings('mu','beta',
+        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
 
-        if
+        if hvp_method == 'autograd':
             with torch.enable_grad():
-                grad =
-                hvp_ = TensorList(hvp(
+                grad = var.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
 
-        elif
-
-            l, hvp_ = hvp_fd_forward(
-            if
+        elif hvp_method == 'forward':
+            var.get_grad()
+            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
-        elif
-            l, hvp_ = hvp_fd_central(
-            if
+        elif hvp_method == 'central':
+            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
         else:
-            raise ValueError(
+            raise ValueError(hvp_method)
 
         if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(
+            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
 
-        update = TensorList(
+        update = TensorList(var.get_update())
 
         hvp_ = as_tensorlist(hvp_)
         update.add_(prev_update - hvp_*mu)
         prev_update.set_(update * beta)
-
-        return
+        var.update = update
+        return var
 
 
 class AdaptiveMatrixMomentum(Module):
     """
-
+    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
+    Evaluates hessian vector product on each step (via finite difference or autograd).
+
+    This version estimates mu via a simple heuristic: ||s||/||y||, where s is parameter difference, y is gradient difference.
+
+    Args:
+        mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
+        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            How to calculate hessian-vector products.
+            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-
-
+
+    def __init__(
+        self,
+        mu_mul: float = 1,
+        beta: float = 1,
+        eps=1e-4,
+        hvp_method: Literal["autograd", "forward", "central"] = "forward",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
     @torch.no_grad
-    def step(self,
-        assert
-        prev_update, prev_params, prev_grad = self.get_state('prev_update', 'prev_params', 'prev_grad',
+    def step(self, var):
+        assert var.closure is not None
+        prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
 
-        settings = self.settings[
-
+        settings = self.settings[var.params[0]]
+        hvp_method = settings['hvp_method']
         h = settings['h']
         eps = settings['eps']
 
-        mu_mul, beta = self.get_settings('mu_mul','beta',
+        mu_mul, beta = self.get_settings(var.params, 'mu_mul','beta', cls=NumberList)
 
-        if
+        if hvp_method == 'autograd':
             with torch.enable_grad():
-                grad =
-                hvp_ = TensorList(hvp(
+                grad = var.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
 
-        elif
-
-            l, hvp_ = hvp_fd_forward(
-            if
+        elif hvp_method == 'forward':
+            var.get_grad()
+            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
-        elif
-            l, hvp_ = hvp_fd_central(
-            if
+        elif hvp_method == 'central':
+            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
         else:
-            raise ValueError(
+            raise ValueError(hvp_method)
 
         if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(
+            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
 
         # adaptive part
-        update = TensorList(
+        update = TensorList(var.get_update())
 
-        s_k =
-        prev_params.copy_(
+        s_k = var.params - prev_params
+        prev_params.copy_(var.params)
 
-        assert
-        y_k =
-        prev_grad.copy_(
+        assert var.grad is not None
+        y_k = var.grad - prev_grad
+        prev_grad.copy_(var.grad)
 
         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
 
@@ -119,6 +161,6 @@ class AdaptiveMatrixMomentum(Module):
         hvp_ = as_tensorlist(hvp_)
         update.add_(prev_update - hvp_*ada_mu)
         prev_update.set_(update * beta)
-
-        return
+        var.update = update
+        return var
 
```
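For orientation, the recurrence both classes implement (following Orr & Leen) is `update_t = g_t + d_{t-1} - mu * H d_{t-1}` with buffer `d_t = beta * update_t`; `AdaptiveMatrixMomentum` additionally sets `mu = mu_mul * ||s|| / (||y|| + eps)`. A self-contained toy sketch of that recurrence on a quadratic with exact Hessian-vector products (illustrative names only, not torchzero API):

```python
import torch

# Matrix momentum on f(x) = 0.5 * x^T A x, where H @ d = A @ d is exact.
# Per the docstring, mu must stay below 1 / (largest eigenvalue of A).
A = torch.diag(torch.tensor([1.0, 10.0]))   # ill-conditioned quadratic
x = torch.tensor([5.0, 5.0])
d = torch.zeros(2)                          # the prev_update buffer
mu, beta, lr = 0.09, 1.0, 0.05              # mu < 1/10

for _ in range(200):
    g = A @ x                               # gradient of the quadratic
    update = g + d - mu * (A @ d)           # update.add_(prev_update - hvp_*mu)
    d = beta * update                       # prev_update.set_(update * beta)
    x = x - lr * update

print(x)  # approaches the minimizer at the origin
```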
torchzero/modules/momentum/momentum.py CHANGED

```diff
@@ -3,11 +3,22 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from .ema import EMA
 
 
 class HeavyBall(EMA):
+    """Polyak's momentum (heavy-ball method).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
         super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
 
@@ -30,14 +41,23 @@ def nag_(
 
 
 class NAG(Transform):
+    """Nesterov accelerated gradient method (nesterov momentum).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
         defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        velocity =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']
 
-        momentum,dampening =
+        momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
         return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
```
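The `lerp` flag in both docstrings switches the velocity buffer between classic accumulation and linear interpolation. A minimal sketch of the two buffer updates (hypothetical helper, not the torchzero implementation):

```python
import torch

def velocity_step(velocity, grad, momentum=0.9, dampening=0.0, lerp=False):
    """One buffer update; lerp=True turns it into an exponential moving average."""
    if lerp:
        # velocity <- momentum * velocity + (1 - momentum) * grad
        velocity.lerp_(grad, 1 - momentum)
    else:
        # velocity <- momentum * velocity + (1 - dampening) * grad
        velocity.mul_(momentum).add_(grad, alpha=1 - dampening)
    return velocity

v = torch.zeros(3)
for _ in range(5):
    velocity_step(v, torch.ones(3))
print(v)  # grows toward 1 / (1 - 0.9) = 10 with the classic form
```

The lerp form keeps the buffer on the same scale as the gradients, which is why the docstring describes it as an EMA.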
torchzero/modules/ops/accumulate.py CHANGED

```diff
@@ -5,61 +5,91 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList, NumberList
+from ...utils import TensorList, NumberList, unpack_states, unpack_dicts
 
 class AccumulateSum(Transform):
+    """Accumulates sum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        sum =
-        decay =
-        return sum.add_(tensors).lazy_mul(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return sum.add_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMean(Transform):
+    """Accumulates mean of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-        mean =
-        decay =
-        return mean.add_(tensors).lazy_mul(
+        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)
 
 class AccumulateProduct(Transform):
+    """Accumulates product of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        prod =
-        decay =
-        return prod.mul_(tensors).lazy_mul(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return prod.mul_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMaximum(Transform):
+    """Accumulates maximum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        maximum =
-        decay =
-        return maximum.maximum_(tensors).lazy_mul(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMinimum(Transform):
+    """Accumulates minimum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        minimum =
-        decay =
-        return minimum.minimum_(tensors).lazy_mul(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
 
```
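All five accumulators share one shape: `unpack_states` fetches a per-parameter buffer, the buffer absorbs the incoming tensors in place, and the value handed downstream is the buffer scaled by `1 - decay` while the buffer itself keeps the raw accumulation. A toy single-tensor sketch of `AccumulateSum`'s semantics as reconstructed above (not the torchzero code):

```python
import torch

def accumulate_sum(buffer: torch.Tensor, update: torch.Tensor, decay: float = 0.0):
    buffer.add_(update)           # running sum, mutated in place
    return buffer * (1 - decay)   # a decayed copy becomes the new update

buf = torch.zeros(4)
for _ in range(3):
    out = accumulate_sum(buf, torch.full((4,), 2.0), decay=0.1)
print(buf)  # tensor([6., 6., 6., 6.])  raw running sum kept in state
print(out)  # tensor([5.4, 5.4, 5.4, 5.4])  scaled copy handed downstream
```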
torchzero/modules/ops/binary.py CHANGED

```diff
@@ -7,7 +7,7 @@ from typing import Any
 
 import torch
 
-from ...core import Chainable, Module, Target,
+from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist
 
 
@@ -26,25 +26,25 @@ class BinaryOperation(Module, ABC):
             self.operands[k] = v
 
     @abstractmethod
-    def transform(self,
+    def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
     @torch.no_grad
-    def step(self,
+    def step(self, var: Var) -> Var:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-
-                processed_operands[k] =
-
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[k] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(
-
-        return
+        transformed = self.transform(var, update=var.get_update(), **processed_operands)
+        var.update = list(transformed)
+        return var
 
 
 class Add(BinaryOperation):
@@ -53,9 +53,9 @@ class Add(BinaryOperation):
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self,
-        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[
-        else: torch._foreach_add_(update, other, alpha=self.settings[
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[var.params[0]]['alpha'])
+        else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update
 
 class Sub(BinaryOperation):
@@ -64,9 +64,9 @@ class Sub(BinaryOperation):
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self,
-        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[
-        else: torch._foreach_sub_(update, other, alpha=self.settings[
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[var.params[0]]['alpha'])
+        else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update
 
 class RSub(BinaryOperation):
@@ -74,7 +74,7 @@ class RSub(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other - TensorList(update)
 
 class Mul(BinaryOperation):
@@ -82,7 +82,7 @@ class Mul(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_mul_(update, other)
         return update
 
@@ -91,7 +91,7 @@ class Div(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_div_(update, other)
         return update
 
@@ -100,7 +100,7 @@ class RDiv(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other / TensorList(update)
 
 class Pow(BinaryOperation):
@@ -108,7 +108,7 @@ class Pow(BinaryOperation):
         super().__init__({}, exponent=exponent)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
         torch._foreach_pow_(update, exponent)
         return update
 
@@ -117,7 +117,7 @@ class RPow(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
         torch._foreach_pow_(other, update)
         return other
@@ -128,8 +128,8 @@ class Lerp(BinaryOperation):
         super().__init__(defaults, end=end)
 
     @torch.no_grad
-    def transform(self,
-        torch._foreach_lerp_(update, end, weight=self.get_settings('weight'
+    def transform(self, var, update: list[torch.Tensor], end: list[torch.Tensor]):
+        torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
         return update
 
 class CopySign(BinaryOperation):
@@ -137,7 +137,7 @@ class CopySign(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [u.copysign_(o) for u, o in zip(update, other)]
 
 class RCopySign(BinaryOperation):
@@ -145,7 +145,7 @@ class RCopySign(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [o.copysign_(u) for u, o in zip(update, other)]
 CopyMagnitude = RCopySign
 
@@ -154,7 +154,7 @@ class Clip(BinaryOperation):
         super().__init__({}, min=min, max=max)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
         return TensorList(update).clamp_(min=min, max=max)
 
 class MirroredClip(BinaryOperation):
@@ -163,7 +163,7 @@ class MirroredClip(BinaryOperation):
         super().__init__({}, value=value)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], value: float | list[torch.Tensor]):
         min = -value if isinstance(value, (int,float)) else [-v for v in value]
         return TensorList(update).clamp_(min=min, max=value)
 
@@ -174,8 +174,8 @@ class Graft(BinaryOperation):
         super().__init__(defaults, magnitude=magnitude)
 
     @torch.no_grad
-    def transform(self,
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[
+    def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
 
 class RGraft(BinaryOperation):
@@ -186,8 +186,8 @@ class RGraft(BinaryOperation):
         super().__init__(defaults, direction=direction)
 
     @torch.no_grad
-    def transform(self,
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[
+    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
         return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
 
 GraftToUpdate = RGraft
@@ -197,7 +197,7 @@ class Maximum(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_maximum_(update, other)
         return update
 
@@ -206,7 +206,7 @@ class Minimum(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_minimum_(update, other)
         return update
 
@@ -217,7 +217,7 @@ class GramSchimdt(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         update = TensorList(update); other = TensorList(other)
         return update - (other*update) / ((other*other) + 1e-8)
 
@@ -229,8 +229,8 @@ class Threshold(BinaryOperation):
         super().__init__(defaults, threshold=threshold, value=value)
 
     @torch.no_grad
-    def transform(self,
-        update_above = self.settings[
+    def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
+        update_above = self.settings[var.params[0]]['update_above']
         update = TensorList(update)
         if update_above:
             if isinstance(value, list): return update.where_(update>threshold, value)
```
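The pattern behind every class in this file: `BinaryOperation.step` runs each module-valued operand on a clone of the current update, substitutes the result into the operand dict, and hands everything to `transform`. A stripped-down sketch of that dispatch (stand-in types; torchzero's `Module`/`Var` machinery is elided):

```python
import torch
from typing import Callable, Union

# An operand may be a constant, a tensor list, or a nested "module"
# (here modeled as a plain callable on a list of tensors).
Operand = Union[float, list, Callable[[list], list]]

def binary_step(update: list, other: Operand, op) -> list:
    if callable(other):
        # module-valued operand: evaluated on its own clone of the update
        other = other([u.clone() for u in update])
    return op(update, other)

def add(update, other, alpha=1.0):
    # mirrors Add.transform: in-place foreach add with an alpha multiplier
    if isinstance(other, (int, float)):
        torch._foreach_add_(update, other * alpha)
    else:
        torch._foreach_add_(update, other, alpha=alpha)
    return update

doubled = lambda tensors: [t * 2 for t in tensors]  # a stand-in "module"
print(binary_step([torch.ones(2)], doubled, add))   # [tensor([3., 3.])]
```

Cloning the update before stepping the operand is the key design choice: it lets both sides of the binary op see the same incoming update without aliasing.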
torchzero/modules/ops/debug.py CHANGED

```diff
@@ -10,16 +10,16 @@ class PrintUpdate(Module):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def step(self,
-        self.settings[
-        return
+    def step(self, var):
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.update}')
+        return var
 
 class PrintShape(Module):
     def __init__(self, text = 'shapes = ', print_fn = print):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def step(self,
-        shapes = [u.shape for u in
-        self.settings[
-        return
+    def step(self, var):
+        shapes = [u.shape for u in var.update] if var.update is not None else None
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
+        return var
```
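Both hunks follow the release-wide migration visible throughout this diff: every module now implements a uniform `step(var) -> var` protocol, reading what it needs from `var`, optionally mutating `var.update`, and passing `var` through. A hypothetical pass-through module in that style (sketch only; `PrintNorm` is not part of torchzero, and `SimpleNamespace` stands in for the real `Var`):

```python
import torch
from types import SimpleNamespace

class PrintNorm:
    """Logs the global norm of the current update, then passes var through."""
    def step(self, var):
        if var.update is not None:
            norm = torch.cat([u.ravel() for u in var.update]).norm()
            print(f'update norm = {norm:.4f}')
        return var

# stand-in for torchzero's Var, which carries params/update/grad/closure
var = SimpleNamespace(update=[torch.ones(3), torch.zeros(2)])
PrintNorm().step(var)  # prints: update norm = 1.7321
```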