torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
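The listing above also captures two structural moves in 0.4.0: the linear-algebra helpers leave `torchzero.utils.linalg` for a new top-level `torchzero.linalg` package, and the single `torchzero/optim/wrappers/scipy.py` module becomes a `scipy/` package with one module per algorithm. A rough before/after of the import paths, inferred only from the file moves above (submodule names are taken from the file list; anything beyond that is an assumption):

# torchzero 0.3.14 layout (hypothetical usage; module names from the file list)
from torchzero.utils.linalg import qr, solve, linear_operator
from torchzero.optim.wrappers import scipy  # one ~570-line module

# torchzero 0.4.0 layout
from torchzero.linalg import qr, solve, linear_operator, orthogonalize
from torchzero.optim.wrappers.scipy import minimize, basin_hopping, differential_evolution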
tests/test_objective.py
ADDED
@@ -0,0 +1,188 @@
+import pytest
+import torch
+from torchzero.core import Objective
+from torchzero.utils.tensorlist import TensorList
+
+@torch.no_grad
+def test_get_loss():
+
+    # ---------------------------- test that it works ---------------------------- #
+    params = [torch.tensor(2.0, requires_grad=True)]
+    evaluated = False
+
+    def closure_1(backward=True):
+        assert not backward, 'backward = True'
+
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0]**2
+        if backward:
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    obj = Objective(params=params, closure=closure_1, model=None, current_step=0)
+
+    assert obj.loss is None, obj.loss
+
+    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+    assert evaluated, evaluated
+    assert loss is obj.loss
+    assert obj.loss == 4.0
+    assert obj.loss_approx == 4.0
+    assert obj.grads is None, obj.grads
+
+    # reevaluate, which should just return already evaluated loss
+    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+    assert obj.grads is None, obj.grads
+
+
+    # ----------------------- test that backward=True works ---------------------- #
+    params = [torch.tensor(3.0, requires_grad=True)]
+    evaluated = False
+
+    def closure_2(backward=True):
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0] * 2
+        if backward:
+            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    obj = Objective(params=params, closure=closure_2, model=None, current_step=0)
+    assert obj.grads is None, obj.grads
+    assert (loss := obj.get_loss(backward=True)) == 6.0, loss
+    assert obj.grads is not None
+    assert obj.grads[0] == 2.0, obj.grads
+
+    # reevaluate, which should just return already evaluated loss
+    assert (loss := obj.get_loss(backward=True)) == 6.0, loss
+    assert obj.grads[0] == 2.0, obj.grads
+
+    # get grad, which should just return already evaluated grad
+    assert (grad := obj.get_grads())[0] == 2.0, grad
+    assert grad is obj.grads, grad
+
+    # get update, which should create and return cloned grad
+    assert obj.updates is None
+    assert (update := obj.get_updates())[0] == 2.0, update
+    assert update is obj.updates
+    assert update is not obj.grads
+    assert obj.grads is not None
+    assert update[0] == obj.grads[0]
+
+@torch.no_grad
+def test_get_grad():
+    params = [torch.tensor(2.0, requires_grad=True)]
+    evaluated = False
+
+    def closure(backward=True):
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0]**2
+        if backward:
+            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    obj = Objective(params=params, closure=closure, model=None, current_step=0)
+    assert (grad := obj.get_grads())[0] == 4.0, grad
+    assert grad is obj.grads
+
+    assert obj.loss == 4.0
+    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+    assert (loss := obj.get_loss(backward=True)) == 4.0, loss
+    assert obj.loss_approx == 4.0
+
+    assert obj.updates is None, obj.updates
+    assert (update := obj.get_updates())[0] == 4.0, update
+
+@torch.no_grad
+def test_get_update():
+    params = [torch.tensor(2.0, requires_grad=True)]
+    evaluated = False
+
+    def closure(backward=True):
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0]**2
+        if backward:
+            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    obj = Objective(params=params, closure=closure, model=None, current_step=0)
+    assert obj.updates is None, obj.updates
+    assert (update := obj.get_updates())[0] == 4.0, update
+    assert update is obj.updates
+
+    assert (grad := obj.get_grads())[0] == 4.0, grad
+    assert grad is obj.grads
+    assert grad is not update
+
+    assert obj.loss == 4.0
+    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+    assert (loss := obj.get_loss(backward=True)) == 4.0, loss
+    assert obj.loss_approx == 4.0
+
+    assert (update := obj.get_updates())[0] == 4.0, update
+
+
+def _assert_objectives_are_same_(o1: Objective, o2: Objective, clone_update: bool):
+    for k,v in o1.__dict__.items():
+        if not k.startswith('__'):
+            # if k == 'post_step_hooks': continue
+            if k == 'storage': continue
+            elif k == 'updates' and clone_update:
+                if o1.updates is None or o2.updates is None:
+                    assert o1.updates is None and o2.updates is None, f'`{k}` attribute is not the same, {o1.updates = }, {o2.updates = }'
+                else:
+                    assert (TensorList(o1.updates) == TensorList(o2.updates)).global_all()
+                    assert o1.updates is not o2.updates
+            elif k == 'params':
+                for p1, p2 in zip(o1.params, o2.params):
+                    assert p1.untyped_storage() == p2.untyped_storage()
+            else:
+                assert getattr(o2, k) is v, f'`{k}` attribute is not the same, {getattr(o1, k) = }, {getattr(o2, k) = }'
+
+def test_var_clone():
+    model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
+    def closure(backward): return 1
+    obj = Objective(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+
+    _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+    _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
+
+    obj.grads = TensorList(torch.randn(5))
+    _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+    _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
+
+    obj.updates = TensorList(torch.randn(5) * 2)
+    obj.loss = torch.randn(1)
+    obj.loss_approx = obj.loss
+    _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+    _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
tests/test_opts.py
CHANGED
@@ -4,15 +4,23 @@ Sanity tests to make sure everything works.
 This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
 don't error or become unhinged with different parameter shapes.
 """
+import random
 from collections.abc import Callable
 from functools import partial
 
+import numpy as np
 import pytest
 import torch
+
 import torchzero as tz
 
 PRINT = False # set to true in nbs
 
+# seed
+torch.manual_seed(0)
+np.random.seed(0)
+random.seed(0)
+
 def _booth(x, y):
     return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2
 
@@ -51,7 +59,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
     losses = []
     for i in range(steps):
         if clear and i == steps//2:
-            for m in opt.
+            for m in opt.flat_modules: m.reset() # clear on middle step to see if there are any issues with it
 
         if use_closure:
             def closure(backward=True):
@@ -283,8 +291,8 @@ ClipNormGrowth_additive = Run(
     sphere_steps=10, sphere_loss=10,
 )
 ClipNormGrowth_global = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(
+    func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
     needs_closure=False,
     func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=10, sphere_loss=10,
@@ -340,56 +348,56 @@ RandomizedFDM_central2 = Run(
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 RandomizedFDM_forward2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 RandomizedFDM_backward2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 RandomizedFDM_forward3 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 RandomizedFDM_backward3 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 RandomizedFDM_central4 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 RandomizedFDM_forward4 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 RandomizedFDM_forward5 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=10, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=420,
 )
 
 
@@ -427,35 +435,35 @@ ForwardGradient = Run(
     sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=40, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=450,
 )
 ForwardGradient_forward = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=40, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=450,
 )
 ForwardGradient_central = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.01)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=40, merge_invariant=True,
-    sphere_steps=
+    sphere_steps=200, sphere_loss=450,
 )
 ForwardGradient_4samples = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=0.1, merge_invariant=True,
-    sphere_steps=100, sphere_loss=
+    sphere_steps=100, sphere_loss=420,
 )
 ForwardGradient_4samples_no_pre_generate = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.001)),
     needs_closure=True,
     func='booth', steps=50, loss=0.1, merge_invariant=True,
-    sphere_steps=100, sphere_loss=
+    sphere_steps=100, sphere_loss=420,
 )
 
 # ------------------------- line_search/backtracking ------------------------- #
@@ -598,15 +606,15 @@ ScaleModulesByCosineSimilarity = Run(
 
 # ------------------------- momentum/matrix_momentum ------------------------- #
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_forward'),),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward')),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.01,
 )
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_central')),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central')),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.01,
@@ -620,15 +628,15 @@ MatrixMomentum_forward = Run(
 )
 
 AdaptiveMatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.05,
 )
 AdaptiveMatrixMomentum_central = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.05,
@@ -642,15 +650,15 @@ AdaptiveMatrixMomentum_autograd = Run(
 )
 
 StochasticAdaptiveMatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.05,
 )
 StochasticAdaptiveMatrixMomentum_central = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.05,
@@ -720,10 +728,11 @@ Adam = Run(
 # ------------------------------ optimizers/soap ----------------------------- #
 SOAP = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(0.4)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
     needs_closure=False,
+    # merge and unmerge lrs are very different so need to test convergence separately somewhere
     func='rosen', steps=50, loss=4, merge_invariant=False,
-    sphere_steps=20, sphere_loss=25,
+    sphere_steps=20, sphere_loss=25,
 )
 # ------------------------------ optimizers/lion ----------------------------- #
 Lion = Run(
@@ -735,11 +744,12 @@ Lion = Run(
 )
 # ---------------------------- optimizers/shampoo ---------------------------- #
 Shampoo = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.
-    sphere_opt=lambda p: tz.Modular(p, tz.m.
+    func_opt=lambda p: tz.Modular(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
     needs_closure=False,
+    # merge and unmerge lrs are very different so need to test convergence separately somewhere
     func='booth', steps=50, loss=0.02, merge_invariant=False,
-    sphere_steps=20, sphere_loss=1,
+    sphere_steps=20, sphere_loss=1,
 )
 
 # ------------------------- quasi_newton/quasi_newton ------------------------ #
@@ -755,6 +765,7 @@ SR1 = Run(
     sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
     needs_closure=True,
     func='rosen', steps=50, loss=1e-12, merge_invariant=True,
+    # this reaches 1e-13 on github so don't change to 0
     sphere_steps=10, sphere_loss=0,
 )
 SSVM = Run(
@@ -806,7 +817,7 @@ NewtonCG = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
     needs_closure=True,
-    func='rosen', steps=20, loss=1e-
+    func='rosen', steps=20, loss=1e-10, merge_invariant=True,
     sphere_steps=2, sphere_loss=3e-4,
 )
 
@@ -872,8 +883,8 @@ SophiaH = Run(
 
 # -------------------------- higher_order ------------------------- #
 HigherOrderNewton = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
+    func_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(trust_method=None)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(2, trust_method=None)),
     needs_closure=True,
     func='rosen', steps=1, loss=2e-10, merge_invariant=True,
     sphere_steps=1, sphere_loss=1e-10,
tests/test_tensorlist.py
CHANGED
@@ -1567,13 +1567,6 @@ def test_where(simple_tl: TensorList):
     assert_tl_allclose(result_module, expected_tl)
 
 
-    # Test inplace where_ (needs TensorList other)
-    tl_copy = simple_tl.clone()
-    result_inplace = tl_copy.where_(condition_tl, other_tl)
-    assert result_inplace is tl_copy
-    assert_tl_allclose(tl_copy, expected_tl)
-
-
 def test_masked_fill(simple_tl: TensorList):
     mask_tl = simple_tl.lt(0)
     fill_value_scalar = 99.0
@@ -1600,7 +1593,6 @@ def test_select_set_(simple_tl: TensorList):
     mask_tl = simple_tl.gt(0.5)
     value_scalar = -1.0
     value_list_scalar = [-1.0, -2.0, -3.0]
-    value_tl = simple_tl.clone().mul_(0.1)
 
     # Set with scalar value
     tl_copy_scalar = simple_tl.clone()
tests/test_utils_optimizer.py
CHANGED
torchzero/__init__.py
CHANGED
torchzero/core/__init__.py
CHANGED
@@ -1,2 +1,8 @@
-from .
-from .
+from .transform import TensorTransform, Transform
+from .module import Chainable, Module
+from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective
+
+# order is important to avoid circular imports
+from .modular import Modular
+from .functional import apply, step, step_tensors, update
+from .chain import Chain, maybe_chain
torchzero/core/chain.py
ADDED
@@ -0,0 +1,47 @@
+from collections.abc import Iterable
+
+from ..utils.python_tools import flatten
+from .module import Module, Chainable
+from .functional import _chain_step
+
+class Chain(Module):
+    """Chain modules, mostly used internally"""
+    def __init__(self, *modules: Module | Iterable[Module]):
+        super().__init__()
+        flat_modules: list[Module] = flatten(modules)
+        for i, module in enumerate(flat_modules):
+            self.set_child(f'module_{i}', module)
+
+    def update(self, objective):
+        if len(self.children) > 1:
+            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+        if len(self.children) == 0: return
+        return self.children['module_0'].update(objective)
+
+    def apply(self, objective):
+        if len(self.children) > 1:
+            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+        if len(self.children) == 0: return objective
+        return self.children['module_0'].apply(objective)
+
+    def step(self, objective):
+        children = [self.children[f'module_{i}'] for i in range(len(self.children))]
+        return _chain_step(objective, children)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        if self.children:
+            if s == 'Chain': s = 'C' # to shorten it
+            s = f'{s}({", ".join(str(m) for m in self.children.values())})'
+        return s
+
+def maybe_chain(*modules: Chainable) -> Module:
+    """Returns a single module directly if only one is provided, otherwise wraps them in a ``Chain``."""
+    flat_modules: list[Module] = flatten(modules)
+    if len(flat_modules) == 1:
+        return flat_modules[0]
+    return Chain(*flat_modules)
+
+
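The `maybe_chain` helper above keeps single modules unwrapped and only builds a `Chain` when several are given. A minimal usage sketch, not part of the diff; it assumes `tz.m.RMSprop` and `tz.m.LR` behave as in the tests earlier in this diff:

import torchzero as tz
from torchzero.core import Chain, maybe_chain

single = maybe_chain(tz.m.LR(0.1))                    # one module is returned as-is
chained = maybe_chain(tz.m.RMSprop(), tz.m.LR(0.1))   # several modules are wrapped in a Chain
assert isinstance(chained, Chain)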
torchzero/core/functional.py
ADDED
@@ -0,0 +1,103 @@
+from collections.abc import Mapping, Sequence, Iterable, Callable
+from typing import TYPE_CHECKING, Any
+
+import torch
+
+from .objective import Objective
+
+if TYPE_CHECKING:
+    from .module import Module
+    from .transform import Transform
+
+
+
+def update(
+    objective: "Objective",
+    module: "Transform",
+    states: list[dict[str, Any]] | None = None,
+    settings: Sequence[Mapping[str, Any]] | None = None,
+) -> None:
+    if states is None:
+        assert settings is None
+        module.update(objective)
+
+    else:
+        assert settings is not None
+        module.update_states(objective, states, settings)
+
+def apply(
+    objective: "Objective",
+    module: "Transform",
+    states: list[dict[str, Any]] | None = None,
+    settings: Sequence[Mapping[str, Any]] | None = None,
+) -> "Objective":
+    if states is None:
+        assert settings is None
+        return module.apply(objective)
+
+    else:
+        assert settings is not None
+        return module.apply_states(objective, states, settings)
+
+def _chain_step(objective: "Objective", modules: "Sequence[Module]"):
+    """steps with ``modules`` and returns updated objective, this is used within ``step`` and within ``Chain.step``"""
+    # step
+    for i, module in enumerate(modules):
+        if i!=0: objective = objective.clone(clone_updates=False)
+
+        objective = module.step(objective)
+        if objective.stop: break
+
+    return objective
+
+def step(objective: "Objective", modules: "Module | Sequence[Module]"):
+    """doesn't apply hooks!"""
+    if not isinstance(modules, Sequence):
+        modules = (modules, )
+
+    if len(modules) == 0:
+        raise RuntimeError("`modules` is an empty sequence")
+
+    # if closure is None, assume backward has been called and gather grads
+    if objective.closure is None:
+        objective.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in objective.params]
+
+    # step and return
+    return _chain_step(objective, modules)
+
+
+def step_tensors(
+    modules: "Module | Sequence[Module]",
+    tensors: Sequence[torch.Tensor],
+    params: Iterable[torch.Tensor] | None = None,
+    grads: Sequence[torch.Tensor] | None = None,
+    loss: torch.Tensor | None = None,
+    closure: Callable | None = None,
+    objective: "Objective | None" = None
+) -> list[torch.Tensor]:
+    if objective is not None:
+        if any(i is not None for i in (params, grads, loss, closure)):
+            raise RuntimeError("Specify either `objective` or `(params, grads, loss, closure)`")
+
+    if not isinstance(modules, Sequence):
+        modules = (modules, )
+
+    # make fake params if they are only used for shapes
+    if params is None:
+        params = [t.view_as(t).requires_grad_() for t in tensors]
+
+    # create objective
+    if objective is None:
+        objective = Objective(params=params, loss=loss, closure=closure)
+
+    if grads is not None:
+        objective.grads = list(grads)
+
+    objective.updates = list(tensors)
+
+    # step with modules
+    # this won't update parameters in-place because objective.Modular is None
+    objective = _chain_step(objective, modules)
+
+    # return updates
+    return objective.get_updates()
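`step_tensors` above is the part of the new functional API that runs modules on plain tensors without building a `Modular` optimizer, returning the transformed updates. A minimal usage sketch, not from the diff; it assumes `tz.m.LR` acts as a simple update scaler, as it does in the tests earlier in this diff:

import torch
import torchzero as tz
from torchzero.core import step_tensors

# treat some raw tensors as the update and pass them through a module chain
grads = [torch.randn(3), torch.randn(2, 2)]
updates = step_tensors(tz.m.LR(0.5), tensors=grads)  # list[torch.Tensor] with the same shapes as `grads`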