torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_objective.py
ADDED
@@ -0,0 +1,188 @@
import pytest
import torch
from torchzero.core import Objective
from torchzero.utils.tensorlist import TensorList

@torch.no_grad
def test_get_loss():

    # ---------------------------- test that it works ---------------------------- #
    params = [torch.tensor(2.0, requires_grad=True)]
    evaluated = False

    def closure_1(backward=True):
        assert not backward, 'backward = True'

        # ensure closure only evaluates once
        nonlocal evaluated
        assert evaluated is False, 'closure was evaluated twice'
        evaluated = True

        loss = params[0]**2
        if backward:
            params[0].grad = None
            loss.backward()
        else:
            assert not loss.requires_grad, "loss requires grad with backward=False"
        return loss

    obj = Objective(params=params, closure=closure_1, model=None, current_step=0)

    assert obj.loss is None, obj.loss

    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
    assert evaluated, evaluated
    assert loss is obj.loss
    assert obj.loss == 4.0
    assert obj.loss_approx == 4.0
    assert obj.grads is None, obj.grads

    # reevaluate, which should just return already evaluated loss
    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
    assert obj.grads is None, obj.grads


    # ----------------------- test that backward=True works ---------------------- #
    params = [torch.tensor(3.0, requires_grad=True)]
    evaluated = False

    def closure_2(backward=True):
        # ensure closure only evaluates once
        nonlocal evaluated
        assert evaluated is False, 'closure was evaluated twice'
        evaluated = True

        loss = params[0] * 2
        if backward:
            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
            params[0].grad = None
            loss.backward()
        else:
            assert not loss.requires_grad, "loss requires grad with backward=False"
        return loss

    obj = Objective(params=params, closure=closure_2, model=None, current_step=0)
    assert obj.grads is None, obj.grads
    assert (loss := obj.get_loss(backward=True)) == 6.0, loss
    assert obj.grads is not None
    assert obj.grads[0] == 2.0, obj.grads

    # reevaluate, which should just return already evaluated loss
    assert (loss := obj.get_loss(backward=True)) == 6.0, loss
    assert obj.grads[0] == 2.0, obj.grads

    # get grad, which should just return already evaluated grad
    assert (grad := obj.get_grads())[0] == 2.0, grad
    assert grad is obj.grads, grad

    # get update, which should create and return cloned grad
    assert obj.updates is None
    assert (update := obj.get_updates())[0] == 2.0, update
    assert update is obj.updates
    assert update is not obj.grads
    assert obj.grads is not None
    assert update[0] == obj.grads[0]

@torch.no_grad
def test_get_grad():
    params = [torch.tensor(2.0, requires_grad=True)]
    evaluated = False

    def closure(backward=True):
        # ensure closure only evaluates once
        nonlocal evaluated
        assert evaluated is False, 'closure was evaluated twice'
        evaluated = True

        loss = params[0]**2
        if backward:
            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
            params[0].grad = None
            loss.backward()
        else:
            assert not loss.requires_grad, "loss requires grad with backward=False"
        return loss

    obj = Objective(params=params, closure=closure, model=None, current_step=0)
    assert (grad := obj.get_grads())[0] == 4.0, grad
    assert grad is obj.grads

    assert obj.loss == 4.0
    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
    assert (loss := obj.get_loss(backward=True)) == 4.0, loss
    assert obj.loss_approx == 4.0

    assert obj.updates is None, obj.updates
    assert (update := obj.get_updates())[0] == 4.0, update

@torch.no_grad
def test_get_update():
    params = [torch.tensor(2.0, requires_grad=True)]
    evaluated = False

    def closure(backward=True):
        # ensure closure only evaluates once
        nonlocal evaluated
        assert evaluated is False, 'closure was evaluated twice'
        evaluated = True

        loss = params[0]**2
        if backward:
            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
            params[0].grad = None
            loss.backward()
        else:
            assert not loss.requires_grad, "loss requires grad with backward=False"
        return loss

    obj = Objective(params=params, closure=closure, model=None, current_step=0)
    assert obj.updates is None, obj.updates
    assert (update := obj.get_updates())[0] == 4.0, update
    assert update is obj.updates

    assert (grad := obj.get_grads())[0] == 4.0, grad
    assert grad is obj.grads
    assert grad is not update

    assert obj.loss == 4.0
    assert (loss := obj.get_loss(backward=False)) == 4.0, loss
    assert (loss := obj.get_loss(backward=True)) == 4.0, loss
    assert obj.loss_approx == 4.0

    assert (update := obj.get_updates())[0] == 4.0, update


def _assert_objectives_are_same_(o1: Objective, o2: Objective, clone_update: bool):
    for k,v in o1.__dict__.items():
        if not k.startswith('__'):
            # if k == 'post_step_hooks': continue
            if k == 'storage': continue
            elif k == 'updates' and clone_update:
                if o1.updates is None or o2.updates is None:
                    assert o1.updates is None and o2.updates is None, f'`{k}` attribute is not the same, {o1.updates = }, {o2.updates = }'
                else:
                    assert (TensorList(o1.updates) == TensorList(o2.updates)).global_all()
                    assert o1.updates is not o2.updates
            elif k == 'params':
                for p1, p2 in zip(o1.params, o2.params):
                    assert p1.untyped_storage() == p2.untyped_storage()
            else:
                assert getattr(o2, k) is v, f'`{k}` attribute is not the same, {getattr(o1, k) = }, {getattr(o2, k) = }'

def test_var_clone():
    model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
    def closure(backward): return 1
    obj = Objective(params=list(model.parameters()), closure=closure, model=model, current_step=0)

    _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
    _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)

    obj.grads = TensorList(torch.randn(5))
    _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
    _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)

    obj.updates = TensorList(torch.randn(5) * 2)
    obj.loss = torch.randn(1)
    obj.loss_approx = obj.loss
    _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
    _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)