torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
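The file moves above fold torchzero/utils/linalg/ into a new top-level torchzero/linalg/ package and replace torchzero/core/var.py with torchzero/core/objective.py. For downstream code that imported from the old paths, the adjustment looks roughly like the following sketch (based only on the renames listed above; check what 0.4.0 actually re-exports before relying on it):

    # torchzero 0.3.15
    from torchzero.utils.linalg.linear_operator import LinearOperator

    # torchzero 0.4.0
    from torchzero.linalg.linear_operator import LinearOperator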
torchzero/core/modular.py
CHANGED
@@ -1,27 +1,16 @@
 
 import warnings
-from
-from collections import
-from
-from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Literal, cast, final, overload
+from collections import ChainMap
+from collections.abc import MutableMapping
+from typing import Any
 
 import torch
 
-from ..utils import
-    Init,
-    ListLike,
-    Params,
-    _make_param_groups,
-    get_state_vals,
-    vec_to_tensors,
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .module import Chainable, Module
-from .var import Var
+from ..utils.params import Params, _make_param_groups
 from .functional import step
+from .module import Chainable, Module
+from .objective import Objective
+
 
 class _EvalCounterClosure:
     """keeps track of how many times closure has been evaluated, and sets closure return"""
@@ -32,7 +21,7 @@ class _EvalCounterClosure:
 
     def __call__(self, *args, **kwargs):
         if self.closure is None:
-            raise RuntimeError("
+            raise RuntimeError("closure is None in _EvalCounterClosure, and this can't happen")
 
         v = self.closure(*args, **kwargs)
 
@@ -44,17 +33,17 @@ class _EvalCounterClosure:
         return v
 
 
-def
-
+def flatten_modules(*modules: Chainable) -> list[Module]:
+    flat = []
 
     for m in modules:
         if isinstance(m, Module):
-
-
+            flat.append(m)
+            flat.extend(flatten_modules(list(m.children.values())))
         else:
-
+            flat.extend(flatten_modules(*m))
 
-    return
+    return flat
 
 
 # have to inherit from Modular to support lr schedulers
@@ -83,7 +72,7 @@ class Modular(torch.optim.Optimizer):
         self.modules = modules
         """Top-level modules providedduring initialization."""
 
-        self.
+        self.flat_modules = flatten_modules(self.modules)
         """A flattened list of all modules including all children."""
 
         param_groups = _make_param_groups(params, differentiable=False)
@@ -92,7 +81,7 @@ class Modular(torch.optim.Optimizer):
         Each element in the list is ChainDict's 2nd map of a module."""
 
         # make sure there is no more than a single learning rate module
-        lr_modules = [m for m in self.
+        lr_modules = [m for m in self.flat_modules if 'lr' in m.defaults]
         if len(lr_modules) > 1:
             warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to componding of learning rate multiplication with per-parameter learning rates and schedulers.')
 
@@ -100,13 +89,13 @@ class Modular(torch.optim.Optimizer):
         for group in param_groups:
             for k in group:
                 if k in ('params', 'lr'): continue
-                modules_with_k = [m for m in self.
+                modules_with_k = [m for m in self.flat_modules if k in m.defaults and k not in m._overridden_keys]
                 if len(modules_with_k) > 1:
                     warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')
 
         # defaults for schedulers
         defaults = {}
-        for m in self.
+        for m in self.flat_modules: defaults.update(m.defaults)
         super().__init__(param_groups, defaults=defaults)
 
         # note - this is what super().__init__(param_groups, defaults=defaults) does:
@@ -146,7 +135,7 @@ class Modular(torch.optim.Optimizer):
 
         for p in proc_param_group['params']:
             # updates global per-parameter setting overrides (medium priority)
-            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.
+            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.flat_modules]
 
     def state_dict(self):
         all_params = [p for g in self.param_groups for p in g['params']]
@@ -163,7 +152,7 @@ class Modular(torch.optim.Optimizer):
             "params": all_params,
             "groups": groups,
             "defaults": self.defaults,
-            "modules": {i: m.state_dict() for i, m in enumerate(self.
+            "modules": {i: m.state_dict() for i, m in enumerate(self.flat_modules)}
         }
         return state_dict
 
@@ -183,7 +172,7 @@ class Modular(torch.optim.Optimizer):
             self.add_param_group(group)
 
         id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
-        for m, sd in zip(self.
+        for m, sd in zip(self.flat_modules, state_dict['modules'].values()):
             m._load_state_dict(sd, id_to_tensor)
 
 
@@ -201,35 +190,42 @@ class Modular(torch.optim.Optimizer):
            if not p.requires_grad: continue
            for map in self._per_parameter_global_settings[p]: map.update(settings)
 
-        # create
+        # create Objective
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
 
-
-        if closure is None:
-
-            self.num_evaluations += 1
+        counter_closure = None
+        if closure is not None:
+            counter_closure = _EvalCounterClosure(self, closure)
 
-
+        objective = Objective(
+            params=params, closure=counter_closure, model=self.model,
+            current_step=self.current_step, modular=self, loss=loss, storage=kwargs
+        )
 
-        # step
-
+        # step with all modules
+        objective = step(objective, self.modules)
 
-        # apply update
-
-
-
+        # apply update to parameters unless `objective.skip_update = True`
+        # this does:
+        # if not objective.skip_update:
+        #     torch._foreach_sub_(objective.params, objective.get_updates())
+        objective.update_parameters()
 
         # update attributes
-        self.attrs.update(
-        if
-
-        # hooks
-        for hook in var.post_step_hooks:
-            hook(self, var)
+        self.attrs.update(objective.attrs)
+        if objective.should_terminate is not None:
+            self.should_terminate = objective.should_terminate
 
         self.current_step += 1
-
+
+        # apply hooks
+        # this does:
+        # for hook in objective.post_step_hooks:
+        #     hook(objective, modules)
+        objective.apply_post_step_hooks(self.modules)
+
+        # return the first closure evaluation return
+        # could return loss if it was passed but that's pointless
         return self._closure_return
 
     def __repr__(self):
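The rewritten Modular.step builds an Objective from params, the evaluation-counting closure, model, current_step and any extra storage kwargs, runs it through the module chain via step(objective, self.modules), applies the result with objective.update_parameters(), and finally returns the first closure evaluation's return value. A minimal usage sketch of driving the optimizer from user code follows; the tz.m.Adam / tz.m.LR module names and the closure(backward=True) convention are assumptions inferred from the module files listed above, not taken from this diff:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    # compose an optimizer from modules; each step builds an Objective,
    # passes it through the chain, then calls objective.update_parameters()
    opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(100):
        loss = opt.step(closure)  # returns the first closure evaluation's return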
torchzero/core/module.py
CHANGED
@@ -1,24 +1,18 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
-from collections.abc import Callable, Iterable,
-from
-from typing import Any, Literal, cast, final, overload
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, overload, TYPE_CHECKING
 
 import torch
 
-from ..
-
-
-
-
-
-
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .var import Var
+from ..linalg.linear_operator import LinearOperator
+from ..utils.optimizer import Init, ListLike, get_state_vals
+from ..utils.params import Params, _make_param_groups
+from .functional import step_tensors
+
+if TYPE_CHECKING:
+    from .objective import Objective
 
 
 class Module(ABC):
@@ -36,6 +30,7 @@ class Module(ABC):
     """
    def __init__(self, defaults: dict[str, Any] | None = None):
        if defaults is None: defaults = {}
+        if any(isinstance(v, Module) for v in defaults.values()): raise RuntimeError("Passed a module to defaults")
        self.defaults: dict[str, Any] = defaults
 
        # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
@@ -55,7 +50,7 @@ class Module(ABC):
        """A dictionary of child modules."""
 
        self._overridden_keys = set()
-        """tracks keys overridden with
+        """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""
 
 
    def set_param_groups(self, param_groups: Params):
@@ -71,7 +66,12 @@ class Module(ABC):
            self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
        return self
 
-    def set_child(self, key: str, module: "Module | Sequence[Module]"):
+    def set_child(self, key: str, module: "Module | Sequence[Module] | None"):
+        if key in self.children:
+            warnings.warn(f"set_child overwriting child `{key}`")
+
+        if module is None: return
+
        from .chain import maybe_chain
        self.children[key] = maybe_chain(module)
 
@@ -85,6 +85,62 @@ class Module(ABC):
    def get_children_sequence(self, prefix = 'module_'):
        return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]
 
+    def inner_step(
+        self,
+        key: str,
+        objective: "Objective",
+        must_exist: bool = True,
+    ) -> "Objective":
+        """Passes ``objective`` to child and returns it."""
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return objective
+
+        return child.step(objective)
+
+
+    def inner_step_tensors(
+        self,
+        key: str,
+        tensors: list[torch.Tensor],
+        clone: bool,
+        params: Iterable[torch.Tensor] | None = None,
+        grads: Sequence[torch.Tensor] | None = None,
+        loss: torch.Tensor | None = None,
+        closure: Callable | None = None,
+        objective: "Objective | None" = None,
+        must_exist: bool = True
+    ) -> list[torch.Tensor]:
+        """Steps with child module. Can be used to apply transforms to any internal buffers.
+
+        If ``objective`` is specified, other attributes shouldn't to be specified.
+
+        Args:
+            key (str): Child module key.
+            tensors (Sequence[torch.Tensor]): tensors to pass to child module.
+            clone (bool):
+                If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
+                If ``key`` doesn't exist, ``tensors`` are always returned without cloning
+            params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
+            grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
+            loss (torch.Tensor | None, optional): loss. Defaults to None.
+            closure (Callable | None, optional): closure. Defaults to None.
+            must_exist (bool, optional): if True, if ``key`` doesn't exist, raises ``KeyError``. Defaults to True.
+        """
+
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return tensors
+
+        if clone: tensors = [t.clone() for t in tensors]
+        return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
+                            loss=loss, closure=closure, objective=objective)
+
+
    def __repr__(self):
        s = self.__class__.__name__
        if self.children:
@@ -106,7 +162,6 @@ class Module(ABC):
 
    def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
                     *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
-        # if isinstance(params, Vars): params = params.params
        return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
 
 
@@ -176,13 +231,8 @@ class Module(ABC):
            - if state_keys has multiple keys and keys has a single key, return cls.
            - if state_keys has multiple keys and keys has multiple keys, return list of cls.
        """
-        # if isinstance(params, Vars): params = params.params
        return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]
 
-    # def first_setting(self, *keys:str, params:Sequence[torch.Tensor]):
-    #     # if isinstance(params, Vars): params = params.params
-    #     return itemgetter(*keys)(self.settings[params[0]])
-
    def clear_state_keys(self, *keys:str):
        for s in self.state.values():
            for k in keys:
@@ -248,36 +298,73 @@ class Module(ABC):
        # extra info
        self._extra_unpack(state_dict['extra'])
 
-
-
-
-
-
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        """If ``seed=None``, returns ``None``.
+
+        Otherwise, if generator on this device and with this seed hasn't been created,
+        creates it and stores in global state.
+
+        Returns ``torch.Generator``."""
+        if seed is None: return None
 
-
-
+        if device is None: device_obj = torch.get_default_device()
+        else: device_obj = torch.device(device)
+        key = f"__generator-{seed}-{device_obj.type}:{device_obj.index}"
+
+        if key not in self.global_state:
+            self.global_state[key] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state[key]
+
+    def increment_counter(self, key: str, start: int):
+        """first value is ``start``"""
+        value = self.global_state.get(key, start - 1) + 1
+        self.global_state[key] = value
+        return value
+
+    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
+    def update(self, objective:"Objective") -> None:
+        """Updates internal state of this module. This should not modify ``objective.update``.
 
        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively,
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
        """
 
-
-
-
+    @abstractmethod
+    def apply(self, objective: "Objective") -> "Objective":
+        """Updates ``objective`` using the internal state of this module.
+
+        If ``update`` method is defined, ``apply`` shouldn't modify the internal state of this module if possible.
 
        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively,
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
        """
-
+        # if apply is empty, it should be defined explicitly.
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `apply`.")
+
+    def step(self, objective: "Objective") -> "Objective":
+        """Perform a step with this module. Calls ``update``, then ``apply``."""
+        self.update(objective)
+        return self.apply(objective)
 
-    def get_H(self,
+    def get_H(self, objective: "Objective") -> LinearOperator | None:
        """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
        The hessian approximation is assumed to be for all parameters concatenated to a vector."""
        # if this method is not defined it searches in children
        # this should be overwritten to return None if child params are different from this modules params
        H = None
        for k,v in self.children.items():
-            H_v = v.get_H(
+            H_v = v.get_H(objective)
 
            if (H is not None) and (H_v is not None):
                raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
@@ -307,21 +394,14 @@ class Module(ABC):
        """
        for c in self.children.values(): c.reset_for_online()
 
-    def _extra_pack(self):
-        """extra information to store in state_dict of this optimizer.
-        Will be passed to ``_extra_unpack`` when loading the state_dict
+    def _extra_pack(self) -> dict:
+        """extra information to store in ``state_dict`` of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
        return {}
 
-    def _extra_unpack(self,
-        """``_extra_pack`` return will be passed to this method when loading state_dict
+    def _extra_unpack(self, d: dict):
+        """``_extra_pack`` return will be passed to this method when loading ``state_dict``.
        This method is called after loading the rest of the state dict"""
 
-    def get_generator(self, device: torch.types.Device, seed: int | None):
-        if seed is None: return None
-
-        if 'generator' not in self.global_state:
-            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
-
-        return self.global_state['generator']
 
 Chainable = Module | Sequence[Module]
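With the new update/apply split, a custom module only needs to implement apply (and optionally update, so that meta-modules such as tz.m.Online can reuse it); Module.step calls update then apply. A minimal sketch of a stateless transform under that contract: the objective.params and objective.get_updates() accessors are taken from the Modular.step comments above, while the module itself is illustrative and not part of the package:

    import torch
    from torchzero.core.module import Module

    class ClipUpdateValue(Module):
        """Hypothetical module: clamps every update tensor to [-value, value]."""
        def __init__(self, value: float = 1.0):
            super().__init__(defaults=dict(value=value))

        def apply(self, objective):
            # resolve the per-parameter 'value' setting (defaults + overrides)
            values = self.get_settings(objective.params, 'value')
            for u, v in zip(objective.get_updates(), values):
                u.clamp_(-v, v)  # modify the update in place
            return objective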
|