torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
tests/test_vars.py
CHANGED

@@ -1,10 +1,10 @@
 import pytest
 import torch
-from torchzero.core.module import Vars
+from torchzero.core.module import Var
 from torchzero.utils.tensorlist import TensorList

 @torch.no_grad
-def test_vars_get_loss():
+def test_var_get_loss():

     # ---------------------------- test that it works ---------------------------- #
     params = [torch.tensor(2.0, requires_grad=True)]
@@ -26,20 +26,20 @@ def test_vars_get_loss():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss

-    vars = Vars(params=params, closure=closure_1, model=None, current_step=0)
+    var = Var(params=params, closure=closure_1, model=None, current_step=0)

-    assert vars.loss is None, vars.loss
+    assert var.loss is None, var.loss

-    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
     assert evaluated, evaluated
-    assert loss is vars.loss
-    assert vars.loss == 4.0
-    assert vars.loss_approx == 4.0
-    assert vars.grad is None, vars.grad
+    assert loss is var.loss
+    assert var.loss == 4.0
+    assert var.loss_approx == 4.0
+    assert var.grad is None, var.grad

     # reevaluate, which should just return already evaluated loss
-    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
-    assert vars.grad is None, vars.grad
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert var.grad is None, var.grad


     # ----------------------- test that backward=True works ---------------------- #
@@ -61,30 +61,30 @@ def test_vars_get_loss():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss

-    vars = Vars(params=params, closure=closure_2, model=None, current_step=0)
-    assert vars.grad is None, vars.grad
-    assert (loss := vars.get_loss(backward=True)) == 6.0, loss
-    assert vars.grad is not None
-    assert vars.grad[0] == 2.0, vars.grad
+    var = Var(params=params, closure=closure_2, model=None, current_step=0)
+    assert var.grad is None, var.grad
+    assert (loss := var.get_loss(backward=True)) == 6.0, loss
+    assert var.grad is not None
+    assert var.grad[0] == 2.0, var.grad

     # reevaluate, which should just return already evaluated loss
-    assert (loss := vars.get_loss(backward=True)) == 6.0, loss
-    assert vars.grad[0] == 2.0, vars.grad
+    assert (loss := var.get_loss(backward=True)) == 6.0, loss
+    assert var.grad[0] == 2.0, var.grad

     # get grad, which should just return already evaluated grad
-    assert (grad := vars.get_grad())[0] == 2.0, grad
-    assert grad is vars.grad, grad
+    assert (grad := var.get_grad())[0] == 2.0, grad
+    assert grad is var.grad, grad

     # get update, which should create and return cloned grad
-    assert vars.update is None
-    assert (update := vars.get_update())[0] == 2.0, update
-    assert update is vars.update
-    assert update is not vars.grad
-    assert vars.grad is not None
-    assert update[0] == vars.grad[0]
+    assert var.update is None
+    assert (update := var.get_update())[0] == 2.0, update
+    assert update is var.update
+    assert update is not var.grad
+    assert var.grad is not None
+    assert update[0] == var.grad[0]

 @torch.no_grad
-def test_vars_get_grad():
+def test_var_get_grad():
     params = [torch.tensor(2.0, requires_grad=True)]
     evaluated = False

@@ -103,20 +103,20 @@ def test_vars_get_grad():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss

-    vars = Vars(params=params, closure=closure, model=None, current_step=0)
-    assert (grad := vars.get_grad())[0] == 4.0, grad
-    assert grad is vars.grad
+    var = Var(params=params, closure=closure, model=None, current_step=0)
+    assert (grad := var.get_grad())[0] == 4.0, grad
+    assert grad is var.grad

-    assert vars.loss == 4.0
-    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
-    assert (loss := vars.get_loss(backward=True)) == 4.0, loss
-    assert vars.loss_approx == 4.0
+    assert var.loss == 4.0
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert (loss := var.get_loss(backward=True)) == 4.0, loss
+    assert var.loss_approx == 4.0

-    assert vars.update is None, vars.update
-    assert (update := vars.get_update())[0] == 4.0, update
+    assert var.update is None, var.update
+    assert (update := var.get_update())[0] == 4.0, update

 @torch.no_grad
-def test_vars_get_update():
+def test_var_get_update():
     params = [torch.tensor(2.0, requires_grad=True)]
     evaluated = False

@@ -135,27 +135,28 @@ def test_vars_get_update():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss

-    vars = Vars(params=params, closure=closure, model=None, current_step=0)
-    assert vars.update is None, vars.update
-    assert (update := vars.get_update())[0] == 4.0, update
-    assert update is vars.update
+    var = Var(params=params, closure=closure, model=None, current_step=0)
+    assert var.update is None, var.update
+    assert (update := var.get_update())[0] == 4.0, update
+    assert update is var.update

-    assert (grad := vars.get_grad())[0] == 4.0, grad
-    assert grad is vars.grad
+    assert (grad := var.get_grad())[0] == 4.0, grad
+    assert grad is var.grad
     assert grad is not update

-    assert vars.loss == 4.0
-    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
-    assert (loss := vars.get_loss(backward=True)) == 4.0, loss
-    assert vars.loss_approx == 4.0
+    assert var.loss == 4.0
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert (loss := var.get_loss(backward=True)) == 4.0, loss
+    assert var.loss_approx == 4.0

-    assert (update := vars.get_update())[0] == 4.0, update
+    assert (update := var.get_update())[0] == 4.0, update


-def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
+def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
     for k,v in v1.__dict__.items():
         if not k.startswith('__'):
             # if k == 'post_step_hooks': continue
+            if k == 'storage': continue
             if k == 'update' and clone_update:
                 if v1.update is None or v2.update is None:
                     assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
@@ -165,20 +166,20 @@ def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
             else:
                 assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'

-def test_vars_clone():
+def test_var_clone():
     model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
     def closure(backward): return 1
-    vars = Vars(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+    var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)

-    _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
-    _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)

-    vars.grad = TensorList(torch.randn(5))
-    _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
-    _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+    var.grad = TensorList(torch.randn(5))
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)

-    vars.update = TensorList(torch.randn(5) * 2)
-    vars.loss = torch.randn(1)
-    vars.loss_approx = vars.loss
-    _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
-    _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+    var.update = TensorList(torch.randn(5) * 2)
+    var.loss = torch.randn(1)
+    var.loss_approx = var.loss
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
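For orientation, the renamed `Var` object is the per-step context these tests exercise: loss, gradient and update are computed lazily and cached. Below is a minimal sketch of that behaviour, assuming only the constructor and the `closure(backward: bool)` convention visible in the diff above.

```python
import torch
from torchzero.core.module import Var

params = [torch.tensor(2.0, requires_grad=True)]

def closure(backward=True):
    loss = params[0] ** 2          # 4.0 at x = 2.0
    if backward:
        params[0].grad = None
        loss.backward()            # d(x^2)/dx = 2x = 4.0
    return loss

var = Var(params=params, closure=closure, model=None, current_step=0)
loss = var.get_loss(backward=True)   # evaluates the closure once and caches var.loss
grad = var.get_grad()                # reuses the gradient produced by the backward pass
update = var.get_update()            # lazily cloned from grad and cached as var.update
assert update[0] == grad[0] and update is not var.grad
```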
torchzero/core/__init__.py
CHANGED

@@ -1,3 +1,2 @@
-from .module import
-from .transform import Transform, TensorwiseTransform, Target,
-from .preconditioner import Preconditioner, TensorwisePreconditioner
+from .module import Var, Module, Modular, Chain, maybe_chain, Chainable
+from .transform import Transform, TensorwiseTransform, Target, apply_transform
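After this change, `torchzero.core` re-exports exactly the names on the two added lines (the `Preconditioner` re-export is gone together with `torchzero/core/preconditioner.py`, removed in the file list above), so downstream code imports them as:

```python
# Public re-exports of torchzero.core as of 0.3.11, per the two added lines above.
from torchzero.core import (
    Var, Module, Modular, Chain, maybe_chain, Chainable,
    Transform, TensorwiseTransform, Target, apply_transform,
)
```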
torchzero/core/module.py
CHANGED

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
 from collections.abc import Callable, Iterable, MutableMapping, Sequence
 from operator import itemgetter
-from typing import Any, final, overload
+from typing import Any, final, overload, Literal

 import torch

@@ -14,6 +14,7 @@ from ..utils import (
     _make_param_groups,
     get_state_vals,
 )
+from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 from ..utils.python_tools import flatten


@@ -29,8 +30,8 @@ def _closure_backward(closure, params, retain_graph, create_graph):
     return loss

 # region Vars
-# -----------------------------------
-class Vars:
+# ----------------------------------- var ----------------------------------- #
+class Var:
     """
     Holds the state and context passed between optimizer modules during a step.

@@ -74,13 +75,13 @@ class Vars:
         """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
         whereas some other modules require loss strictly at current point."""

-        self.post_step_hooks: list[Callable[[Modular, Vars]]] = []
+        self.post_step_hooks: list[Callable[[Modular, Var]]] = []
         """list of functions to be called after optimizer step.
        The signature is:

        .. code:: py

-            def hook(optimizer: Modular, vars: Vars): ...
+            def hook(optimizer: Modular, var: Vars): ...

        """

@@ -109,8 +110,11 @@ class Vars:
         self.skip_update: bool = False
         """if True, the parameters will not be updated"""

+        self.storage: dict = {}
+        """Storage for any other data, such as hessian estimates, etc"""
+
     def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
-        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`vars.loss`.
+        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
         Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""

         if self.loss is None:
@@ -143,7 +147,7 @@ class Vars:

     def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
         """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
-        :code:`vars.grad` and potentially :code:`vars.loss`. Do not call this at perturbed parameters."""
+        :code:`var.grad` and potentially :code:`var.loss`. Do not call this at perturbed parameters."""
         if self.grad is None:
             if self.closure is None: raise RuntimeError("closure is None")
             self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -152,15 +156,15 @@ class Vars:
         return self.grad

     def get_update(self) -> list[torch.Tensor]:
-        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`vars.update`.
-        Computing the gradients may assign :code:`vars.grad` and :code:`vars.loss` if they haven't been computed.
+        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`var.update`.
+        Computing the gradients may assign :code:`var.grad` and :code:`var.loss` if they haven't been computed.
         Do not call this at perturbed parameters."""
         if self.update is None: self.update = [g.clone() for g in self.get_grad()]
         return self.update

     def clone(self, clone_update: bool):
         """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
-        copy = Vars(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
+        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)

         if clone_update and self.update is not None:
             copy.update = [u.clone() for u in self.update]
@@ -176,16 +180,17 @@ class Vars:

         return copy

-    def update_attrs_from_clone_(self, vars: "Vars"):
+    def update_attrs_from_clone_(self, var: "Var"):
         """Updates attributes of this `Vars` instance from a cloned instance.
         Typically called after a child module has processed a cloned `Vars`
         object. This propagates any newly computed loss or gradient values
         from the child's context back to the parent `Vars` if the parent
         didn't have them computed already.
         """
-        if self.loss is None: self.loss = vars.loss
-        if self.loss_approx is None: self.loss_approx = vars.loss_approx
-        if self.grad is None: self.grad = vars.grad
+        if self.loss is None: self.loss = var.loss
+        if self.loss_approx is None: self.loss_approx = var.loss_approx
+        if self.grad is None: self.grad = var.grad
+        self.storage.update(var.storage)

     def zero_grad(self, set_to_none=True):
         if set_to_none:
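The new `Var.storage` dict gives modules a scratch channel for passing extra data (the docstring mentions Hessian estimates) down the chain, and `update_attrs_from_clone_` now merges it back from a child's cloned `Var`. A hedged sketch of how two cooperating modules might use it; both classes here are hypothetical and not part of torchzero:

```python
import torch
from torchzero.core import Module, Var

class DiagCurvatureEstimate(Module):
    """Hypothetical module that stashes a crude curvature proxy in var.storage."""
    def apply(self, var: Var) -> Var:
        grad = var.get_grad()
        var.storage["diag_estimate"] = [g * g for g in grad]
        return var

class DiagPreconditionedUpdate(Module):
    """Hypothetical module that consumes the stored estimate if an earlier module produced it."""
    def apply(self, var: Var) -> Var:
        diag = var.storage.get("diag_estimate")
        if diag is not None:
            update = var.get_update()
            # divide the update elementwise by (estimate + eps)
            torch._foreach_div_(update, torch._foreach_add(diag, 1e-8))
        return var
```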
@@ -269,36 +274,36 @@ class Module(ABC):
         return s

     @overload
-    def get_settings(self, key: str, *,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: str, *,
+                     cls: type[ListLike] = list) -> ListLike: ...
     @overload
-    def get_settings(self, key: list[str] | tuple[str,...], *,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
+                     cls: type[ListLike] = list) -> list[ListLike]: ...
     @overload
-    def get_settings(self, key: str, key2: str, *keys: str,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
+                     cls: type[ListLike] = list) -> list[ListLike]: ...

-    def get_settings(self, key: str | list[str] | tuple[str,...], key2: str | None = None,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
+                     *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
         # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]


     @overload
-    def get_state(self, key: str, *,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: str, *,
+                  must_exist: bool = False, init: Init = torch.zeros_like,
                   cls: type[ListLike] = list) -> ListLike: ...
     @overload
-    def get_state(self, key: list[str] | tuple[str,...], *,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> list[ListLike]: ...
     @overload
-    def get_state(self, key: str, key2: str, *keys: str,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> list[ListLike]: ...

-    def get_state(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> ListLike | list[ListLike]:
         """Returns values of per-parameter state for a given key.
         If key doesn't exist, create it with inits.
@@ -358,6 +363,26 @@ class Module(ABC):
         # # if isinstance(params, Vars): params = params.params
         # return itemgetter(*keys)(self.settings[params[0]])

+    def clear_state_keys(self, *keys:str):
+        for s in self.state.values():
+            for k in keys:
+                if k in s: del s[k]
+
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: str, values: Sequence): ...
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: Sequence[str], values: Sequence[Sequence]): ...
+    def store(self, params: Sequence[torch.Tensor], keys: str | Sequence[str], values: Sequence):
+        if isinstance(keys, str):
+            for p,v in zip(params, values):
+                state = self.state[p]
+                state[keys] = v
+            return
+
+        for p, *p_v in zip(params, *values):
+            state = self.state[p]
+            for k,v in zip(keys, p_v): state[k] = v
+
     def state_dict(self):
         """state dict"""
         packed_state = {id(k):v for k,v in self.state.items()}
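`store` is a small convenience over the per-parameter `state` dicts: with a single key it writes one value per parameter, and with a sequence of keys it unpacks parallel per-key sequences. A sketch of a module using it together with `get_state` and `clear_state_keys`; the module itself is hypothetical:

```python
import torch
from torchzero.core import Module, Var

class GradDelta(Module):
    """Hypothetical module: exposes the change in gradient since the previous step."""
    def apply(self, var: Var) -> Var:
        params, grad = var.params, var.get_grad()
        prev = self.get_state(params, "prev_grad")                  # zeros_like on first use
        var.storage["grad_delta"] = torch._foreach_sub(grad, prev)
        self.store(params, "prev_grad", [g.clone() for g in grad])  # single key, one value per param
        # the multi-key form takes parallel sequences, e.g.:
        # self.store(params, ("exp_avg", "exp_avg_sq"), (exp_avgs, exp_avg_sqs))
        return var

    def reset(self):
        self.clear_state_keys("prev_grad")
```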
@@ -403,23 +428,111 @@
         self._extra_unpack(state_dict['extra'])

     # ---------------------------- OVERRIDABLE METHODS --------------------------- #
-
-
-
+    def step(self, var: Var) -> Var:
+        """performs a step, returns new var but may update it in-place."""
+        self.update(var)
+        return self.apply(var)
+
+    def update(self, var:Var) -> Any:
+        """Updates the internal state of this module. This should not modify `var.update`.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ::code::`tz.m.Online`.
+        """
+
+    def apply(self, var: Var) -> Var:
+        """Applies this module to ``var.get_update()``. This should not modify the internal state of this module if possible."""
+        raise NotImplementedError(f"{self} doesn't implement the `apply` method.")

     def reset(self):
-        """Resets the internal state of the module (e.g. momentum)."""
+        """Resets the internal state of the module (e.g. momentum). By default clears state and global state."""
         # no complex logic is allowed there because this is overridden by many modules
         # where super().reset() shouldn't be called
         self.state.clear()
         self.global_state.clear()

+    def reset_for_online(self):
+        """resets only the intermediate state of this module, e.g. previous parameters and gradient."""
+        for c in self.children.values(): c.reset_for_online()
+
     def _extra_pack(self):
         return {}

     def _extra_unpack(self, x):
         pass

+
+    # ------------------------------ HELPER METHODS ------------------------------ #
+    @torch.no_grad
+    def Hvp(
+        self,
+        v: Sequence[torch.Tensor],
+        at_x0: bool,
+        var: Var,
+        rgrad: Sequence[torch.Tensor] | None,
+        hvp_method: Literal['autograd', 'forward', 'central'],
+        h: float,
+        normalize: bool,
+        retain_grad: bool,
+    ):
+        """
+        Returns ``(Hvp, rgrad)``. ``rgrad`` is gradient at current parameters, possibly with create_graph=True, or it may be None with ``hvp_method="central"``. Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
+
+        Single sample example:
+
+        .. code:: py
+
+            Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+
+        Multiple samples example:
+
+        .. code:: py
+
+            D = None
+            rgrad = None
+            for i in range(n_samples):
+                v = [torch.randn_like(p) for p in params]
+                Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+        Args:
+            v (Sequence[torch.Tensor]): vector in hessian-vector product
+            at_x0 (bool): whether this is being called at original or perturbed parameters.
+            var (Var): Var
+            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
+            hvp_method (str): hvp method.
+            h (float): finite difference step size
+            normalize (bool): whether to normalize v for finite difference
+            retain_grad (bool): retain grad
+        """
+        # get grad
+        if rgrad is None and hvp_method in ('autograd', 'forward'):
+            if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
+            else:
+                if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
+                with torch.enable_grad():
+                    loss = var.closure()
+                    rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
+
+        if hvp_method == 'autograd':
+            assert rgrad is not None
+            Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
+
+        elif hvp_method == 'forward':
+            assert rgrad is not None
+            loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
+
+        elif hvp_method == 'central':
+            loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
+
+        else:
+            raise ValueError(hvp_method)
+
+        return Hvp, rgrad
+
 # endregion

 Chainable = Module | Sequence[Module]
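The new default `step` simply calls `update` followed by `apply`, so a module that separates "absorb new information" from "produce an update" only needs to implement those two methods; per the docstring, this split is what lets meta-modules such as `tz.m.Online` drive them separately. A hedged sketch of an EMA module written against that interface (the class and its constants are illustrative, not part of torchzero):

```python
import torch
from torchzero.core import Module, Var

class EMAOfGradients(Module):
    """Hypothetical module: update() folds the gradient into an EMA,
    apply() swaps the EMA in as the update. The default step() runs both."""
    def update(self, var: Var) -> None:
        beta = 0.9  # illustrative constant; a real module would read this from its settings
        ema = self.get_state(var.params, "ema")      # created as zeros on first use
        torch._foreach_lerp_(ema, var.get_grad(), 1 - beta)

    def apply(self, var: Var) -> Var:
        ema = self.get_state(var.params, "ema")
        var.update = [e.clone() for e in ema]
        return var
```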
@@ -440,6 +553,21 @@ def unroll_modules(*modules: Chainable) -> list[Module]:

 # region Modular
 # ---------------------------------- Modular --------------------------------- #
+
+class _EvalCounterClosure:
+    """keeps track of how many times closure has been evaluated"""
+    __slots__ = ("modular", "closure")
+    def __init__(self, modular: "Modular", closure):
+        self.modular = modular
+        self.closure = closure
+
+    def __call__(self, *args, **kwargs):
+        if self.closure is None:
+            raise RuntimeError("One of the modules requires closure to be passed to the step method")
+
+        self.modular.num_evaluations += 1
+        return self.closure(*args, **kwargs)
+
 # have to inherit from Modular to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
 class Modular(torch.optim.Optimizer):
@@ -496,7 +624,10 @@ class Modular(torch.optim.Optimizer):
         # self.add_param_group(param_group)

         self.current_step = 0
-        """
+        """global step counter for the optimizer."""
+
+        self.num_evaluations = 0
+        """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""

     def add_param_group(self, param_group: dict[str, Any]):
         proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
@@ -556,13 +687,14 @@ class Modular(torch.optim.Optimizer):
             if not p.requires_grad: continue
             for map in self._per_parameter_global_settings[p]: map.update(settings)

-        # create vars
+        # create var
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        vars = Vars(params=params, closure=closure, model=self.model, current_step=self.current_step)
+        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step)

         # if closure is None, assume backward has been called and gather grads
         if closure is None:
-            vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            self.num_evaluations += 1

         last_module = self.modules[-1]
         last_lr = last_module.defaults.get('lr', None)
@@ -570,27 +702,27 @@

         # step
         for i, module in enumerate(self.modules):
-            if i!=0: vars = vars.clone(clone_update=False)
+            if i!=0: var = var.clone(clone_update=False)

             # last module, or next to last module before lr
             if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
-                if module.children: vars.nested_is_last = True
-                else: vars.is_last = True
-                if last_lr is not None: vars.last_module_lrs = [last_module.settings[p]['lr'] for p in vars.params]
+                if module.children: var.nested_is_last = True
+                else: var.is_last = True
+                if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]

-            vars = module.step(vars)
-            if vars.stop: break
+            var = module.step(var)
+            if var.stop: break

         # apply update
-        if not vars.skip_update:
+        if not var.skip_update:
             with torch.no_grad():
-                torch._foreach_sub_(params, vars.get_update())
+                torch._foreach_sub_(params, var.get_update())

-        for hook in vars.post_step_hooks:
-            hook(self, vars)
+        for hook in var.post_step_hooks:
+            hook(self, var)

         self.current_step += 1
-        return vars.loss if vars.loss is not None else vars.loss_approx
+        return var.loss if var.loss is not None else var.loss_approx

     def __repr__(self):
         return f'Modular({", ".join(str(m) for m in self.modules)})'
@@ -606,11 +738,11 @@ class Chain(Module):
         for i, module in enumerate(flat_modules):
             self.set_child(f'module_{i}', module)

-    def step(self, vars):
+    def step(self, var):
         for i in range(len(self.children)):
-            vars = self.children[f'module_{i}'].step(vars)
-            if vars.stop: break
-        return vars
+            var = self.children[f'module_{i}'].step(var)
+            if var.stop: break
+        return var

     def __repr__(self):
         s = self.__class__.__name__