torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
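
The rename entries above move the linear-algebra helpers out of `torchzero.utils.linalg` into a new top-level `torchzero.linalg` package. A minimal sketch of what that means for import paths, based only on the renames listed here and on the import change visible in the `module.py` diff below (other re-exports are not shown in this diff):

```python
# torchzero 0.3.15 (old path, removed in this release)
# from torchzero.utils.linalg.linear_operator import LinearOperator

# torchzero 0.4.1 (new path, per the utils/linalg -> linalg renames above)
from torchzero.linalg.linear_operator import LinearOperator
```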
torchzero/core/module.py
CHANGED
```diff
@@ -1,24 +1,18 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
-from collections.abc import Callable, Iterable,
-from
-from typing import Any, Literal, cast, final, overload
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, overload, TYPE_CHECKING
 
 import torch
 
-from ..
-
-
-
-
-
-
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .var import Var
+from ..linalg.linear_operator import LinearOperator
+from ..utils.optimizer import Init, ListLike, get_state_vals
+from ..utils.params import Params, _make_param_groups
+from .functional import step_tensors
+
+if TYPE_CHECKING:
+    from .objective import Objective
 
 
 class Module(ABC):
@@ -36,11 +30,12 @@ class Module(ABC):
     """
     def __init__(self, defaults: dict[str, Any] | None = None):
         if defaults is None: defaults = {}
+        if any(isinstance(v, Module) for v in defaults.values()): raise RuntimeError("Passed a module to defaults")
         self.defaults: dict[str, Any] = defaults
 
         # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
         # 0 - this module specific per-parameter setting overrides set via `set_param_groups` - highest priority
-        # 1 - global per-parameter setting overrides in param_groups passed to
+        # 1 - global per-parameter setting overrides in param_groups passed to Optimizer - medium priority
         # 2 - `defaults` - lowest priority
         self.settings: defaultdict[torch.Tensor, ChainMap[str, Any]] = defaultdict(lambda: ChainMap({}, {}, self.defaults))
         """per-parameter settings."""
@@ -55,7 +50,7 @@ class Module(ABC):
         """A dictionary of child modules."""
 
         self._overridden_keys = set()
-        """tracks keys overridden with
+        """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""
 
 
     def set_param_groups(self, param_groups: Params):
@@ -71,7 +66,12 @@ class Module(ABC):
            self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
        return self
 
-    def set_child(self, key: str, module: "Module | Sequence[Module]"):
+    def set_child(self, key: str, module: "Module | Sequence[Module] | None"):
+        if key in self.children:
+            warnings.warn(f"set_child overwriting child `{key}`")
+
+        if module is None: return
+
         from .chain import maybe_chain
         self.children[key] = maybe_chain(module)
 
@@ -85,6 +85,62 @@ class Module(ABC):
     def get_children_sequence(self, prefix = 'module_'):
         return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]
 
+    def inner_step(
+        self,
+        key: str,
+        objective: "Objective",
+        must_exist: bool = True,
+    ) -> "Objective":
+        """Passes ``objective`` to child and returns it."""
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return objective
+
+        return child.step(objective)
+
+
+    def inner_step_tensors(
+        self,
+        key: str,
+        tensors: list[torch.Tensor],
+        clone: bool,
+        params: Iterable[torch.Tensor] | None = None,
+        grads: Sequence[torch.Tensor] | None = None,
+        loss: torch.Tensor | None = None,
+        closure: Callable | None = None,
+        objective: "Objective | None" = None,
+        must_exist: bool = True
+    ) -> list[torch.Tensor]:
+        """Steps with child module. Can be used to apply transforms to any internal buffers.
+
+        If ``objective`` is specified, other attributes shouldn't to be specified.
+
+        Args:
+            key (str): Child module key.
+            tensors (Sequence[torch.Tensor]): tensors to pass to child module.
+            clone (bool):
+                If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
+                If ``key`` doesn't exist, ``tensors`` are always returned without cloning
+            params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
+            grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
+            loss (torch.Tensor | None, optional): loss. Defaults to None.
+            closure (Callable | None, optional): closure. Defaults to None.
+            must_exist (bool, optional): if True, if ``key`` doesn't exist, raises ``KeyError``. Defaults to True.
+        """
+
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return tensors
+
+        if clone: tensors = [t.clone() for t in tensors]
+        return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
+                            loss=loss, closure=closure, objective=objective)
+
+
     def __repr__(self):
         s = self.__class__.__name__
         if self.children:
@@ -106,7 +162,6 @@ class Module(ABC):
 
     def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
                      *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
-        # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
 
 
@@ -176,13 +231,8 @@ class Module(ABC):
         - if state_keys has multiple keys and keys has a single key, return cls.
         - if state_keys has multiple keys and keys has multiple keys, return list of cls.
         """
-        # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]
 
-    # def first_setting(self, *keys:str, params:Sequence[torch.Tensor]):
-    #     # if isinstance(params, Vars): params = params.params
-    #     return itemgetter(*keys)(self.settings[params[0]])
-
     def clear_state_keys(self, *keys:str):
         for s in self.state.values():
             for k in keys:
@@ -223,7 +273,7 @@ class Module(ABC):
         return state_dict
 
     def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
-        """loads state_dict, ``id_to_tensor`` is passed by ``
+        """loads state_dict, ``id_to_tensor`` is passed by ``Optimizer``"""
         # load state
         state = state_dict['state']
         self.state.clear()
@@ -248,36 +298,73 @@ class Module(ABC):
         # extra info
         self._extra_unpack(state_dict['extra'])
 
-
-
-
-
-
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        """If ``seed=None``, returns ``None``.
+
+        Otherwise, if generator on this device and with this seed hasn't been created,
+        creates it and stores in global state.
+
+        Returns ``torch.Generator``."""
+        if seed is None: return None
 
-
-
+        if device is None: device_obj = torch.get_default_device()
+        else: device_obj = torch.device(device)
+        key = f"__generator-{seed}-{device_obj.type}:{device_obj.index}"
+
+        if key not in self.global_state:
+            self.global_state[key] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state[key]
+
+    def increment_counter(self, key: str, start: int):
+        """first value is ``start``"""
+        value = self.global_state.get(key, start - 1) + 1
+        self.global_state[key] = value
+        return value
+
+    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
+    def update(self, objective:"Objective") -> None:
+        """Updates internal state of this module. This should not modify ``objective.update``.
 
         Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively,
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
         """
 
-
-
-
+    @abstractmethod
+    def apply(self, objective: "Objective") -> "Objective":
+        """Updates ``objective`` using the internal state of this module.
+
+        If ``update`` method is defined, ``apply`` shouldn't modify the internal state of this module if possible.
 
         Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively,
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
         """
-
+        # if apply is empty, it should be defined explicitly.
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `apply`.")
+
+    def step(self, objective: "Objective") -> "Objective":
+        """Perform a step with this module. Calls ``update``, then ``apply``."""
+        self.update(objective)
+        return self.apply(objective)
 
-    def get_H(self,
+    def get_H(self, objective: "Objective") -> LinearOperator | None:
         """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
         The hessian approximation is assumed to be for all parameters concatenated to a vector."""
         # if this method is not defined it searches in children
         # this should be overwritten to return None if child params are different from this modules params
         H = None
         for k,v in self.children.items():
-            H_v = v.get_H(
+            H_v = v.get_H(objective)
 
             if (H is not None) and (H_v is not None):
                 raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
@@ -307,21 +394,14 @@ class Module(ABC):
         """
         for c in self.children.values(): c.reset_for_online()
 
-    def _extra_pack(self):
-        """extra information to store in state_dict of this optimizer.
-        Will be passed to ``_extra_unpack`` when loading the state_dict
+    def _extra_pack(self) -> dict:
+        """extra information to store in ``state_dict`` of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
         return {}
 
-    def _extra_unpack(self,
-        """``_extra_pack`` return will be passed to this method when loading state_dict
+    def _extra_unpack(self, d: dict):
+        """``_extra_pack`` return will be passed to this method when loading ``state_dict``.
         This method is called after loading the rest of the state dict"""
 
-    def get_generator(self, device: torch.types.Device, seed: int | None):
-        if seed is None: return None
-
-        if 'generator' not in self.global_state:
-            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
-
-        return self.global_state['generator']
 
 Chainable = Module | Sequence[Module]
```
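
The new overridable-method section above replaces the old `Var`-based flow with an `Objective` passed through `update`, `apply`, and `step`. Below is a hypothetical no-op module, sketched only against the `Module` methods visible in this diff; the `Objective` class itself lives in the new `torchzero/core/objective.py`, which is not shown here, so the sketch passes it through untouched.

```python
import torch
from torchzero.core.module import Module


class PassThrough(Module):
    """Hypothetical example module (not part of torchzero) illustrating the
    update/apply split introduced in 0.4.1."""

    def __init__(self, seed: int | None = None):
        # defaults may not contain Module instances (new RuntimeError check in __init__)
        super().__init__(defaults={"seed": seed})

    def update(self, objective):
        # internal-state work goes here; guaranteed to run at least once before apply()
        step = self.increment_counter("step", start=0)  # counter kept in self.global_state
        # generators are now cached per (seed, device) key instead of a single 'generator' entry
        gen = self.get_generator("cpu", seed=self.defaults["seed"])

    def apply(self, objective):
        # apply() is now abstract and must be defined; it should transform `objective`
        # using state prepared in update(). This sketch returns it unchanged.
        return objective


# step() calls update() and then apply():
# new_objective = PassThrough(seed=0).step(objective)
```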