torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/trust_region/trust_region.py
CHANGED

@@ -7,9 +7,16 @@ from typing import Any, Literal, Protocol, cast, final, overload

 import torch

-from ...core import Chainable, Module,
-from ...
-from ...utils
+from ...core import Chainable, Module, Objective
+from ...linalg.linear_operator import LinearOperator
+from ...utils import (
+    TensorList,
+    generic_finfo,
+    generic_vector_norm,
+    safe_dict_update_,
+    tofloat,
+    vec_to_tensors,
+)


 def _flatten_tensors(tensors: list[torch.Tensor]):

@@ -256,24 +263,24 @@ class TrustRegionBase(Module, ABC):
         """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
         ... # pylint:disable=unnecessary-ellipsis

-    def trust_region_update(self,
+    def trust_region_update(self, objective: Objective, H: LinearOperator | None) -> None:
         """updates the state of this module after H or B have been updated, if necessary"""

-    def trust_region_apply(self,
-        """Solves the trust region subproblem and outputs ``
+    def trust_region_apply(self, objective: Objective, tensors: list[torch.Tensor], H: LinearOperator | None) -> Objective:
+        """Solves the trust region subproblem and outputs ``Objective`` with the solution direction."""
         assert H is not None

-        params = TensorList(
+        params = TensorList(objective.params)
         settings = self.settings[params[0]]
         g = _flatten_tensors(tensors)

         max_attempts = settings['max_attempts']

         # loss at x_0
-        loss =
-        closure =
+        loss = objective.loss
+        closure = objective.closure
         if closure is None: raise RuntimeError("Trust region requires closure")
-        if loss is None: loss =
+        if loss is None: loss = objective.get_loss(False)
         loss = tofloat(loss)

         # trust region step and update

@@ -313,38 +320,36 @@ class TrustRegionBase(Module, ABC):
         )

         assert d is not None
-        if success:
-        else:
+        if success: objective.updates = vec_to_tensors(d, params)
+        else: objective.updates = params.zeros_like()

-        return
+        return objective


    @final
    @torch.no_grad
-    def update(self,
+    def update(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        if step % self.defaults["update_freq"] == 0:

            hessian_module = self.children['hess_module']
-            hessian_module.update(
-            H = hessian_module.get_H(
+            hessian_module.update(objective)
+            H = hessian_module.get_H(objective)
            self.global_state["H"] = H

-            self.trust_region_update(
+            self.trust_region_update(objective, H=H)


    @final
    @torch.no_grad
-    def apply(self,
+    def apply(self, objective):
        H = self.global_state.get('H', None)

        # -------------------------------- inner step -------------------------------- #
-
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+        objective = self.inner_step("inner", objective, must_exist=False)

        # ----------------------------------- apply ---------------------------------- #
-        return self.trust_region_apply(
+        return self.trust_region_apply(objective=objective, tensors=objective.get_updates(), H=H)

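The hunks above migrate `TrustRegionBase` from the old `var`-based interface to the new `Objective`-based `update()`/`apply()` lifecycle. As a reading aid, here is a minimal sketch of that lifecycle for a custom module; it is inferred only from the signatures visible in this diff, and the import path plus the empty `defaults` dict are assumptions, not part of the package's documented API.

```python
# Hypothetical sketch of the 0.4.0 module lifecycle shown above: update() refreshes
# internal state, apply() consumes objective.get_updates() and writes objective.updates.
import torch
from torchzero.core import Module  # assumed public import path


class SignUpdate(Module):
    """Illustrative module that replaces the current update with its elementwise sign."""

    def __init__(self):
        super().__init__(dict())  # assumed: Module accepts an empty defaults dict

    def update(self, objective):
        # keep a step counter in global_state, mirroring TrustRegionBase.update above
        self.global_state["step"] = self.global_state.get("step", 0) + 1

    @torch.no_grad
    def apply(self, objective):
        objective.updates = [u.sign() for u in objective.get_updates()]
        return objective
```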
torchzero/modules/variance_reduction/svrg.py
CHANGED

@@ -3,15 +3,17 @@ from functools import partial

 import torch

-from ...core
+from ...core import Module, Objective
 from ...utils import tofloat


-def _reset_except_self(
-
+def _reset_except_self(objective: Objective, modules, self: Module):
+    assert objective.modular is not None
+    for m in objective.modular.flat_modules:
         if m is not self:
             m.reset()

+
 class SVRG(Module):
     """Stochastic variance reduced gradient method (SVRG).

@@ -71,7 +73,7 @@ class SVRG(Module):
     ```
     ## Notes

-    The SVRG gradient is computed as ``g_b(x) - alpha * g_b(x_0) - g_f(
+    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
     - ``x`` is current parameters
     - ``x_0`` is initial parameters, where full gradient was computed
     - ``g_b`` refers to mini-batch gradient at ``x`` or ``x_0``
@@ -83,17 +85,18 @@ class SVRG(Module):
         defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
         super().__init__(defaults)

+
     @torch.no_grad
-    def
-        params =
-        closure =
+    def update(self, objective):
+        params = objective.params
+        closure = objective.closure
         assert closure is not None

         if "full_grad" not in self.global_state:

             # -------------------------- calculate full gradient ------------------------- #
-            if "full_closure" in
-                full_closure =
+            if "full_closure" in objective.storage:
+                full_closure = objective.storage['full_closure']
                 with torch.enable_grad():
                     full_loss = full_closure()
                     if all(p.grad is None for p in params):
@@ -116,12 +119,12 @@ class SVRG(Module):

         # accumulate grads
         accumulator = self.get_state(params, 'accumulator')
-        grad =
+        grad = objective.get_grads()
         torch._foreach_add_(accumulator, grad)

         # accumulate loss
         loss_accumulator = self.global_state.get('loss_accumulator', 0)
-        loss_accumulator += tofloat(
+        loss_accumulator += tofloat(objective.loss)
         self.global_state['loss_accumulator'] = loss_accumulator

         # on nth step, use the accumulated gradient
@@ -136,10 +139,10 @@ class SVRG(Module):

         # otherwise skip update until enough grads are accumulated
         else:
-
-
-
-            return
+            objective.updates = None
+            objective.stop = True
+            objective.skip_update = True
+            return


         svrg_steps = self.defaults['svrg_steps']
@@ -194,7 +197,7 @@ class SVRG(Module):

             return closure(False)

-
+        objective.closure = svrg_closure

         # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
         if current_svrg_step >= svrg_steps:
@@ -203,6 +206,6 @@ class SVRG(Module):
             del self.global_state['full_loss']
             del self.global_state['x_0']
             if self.defaults['reset_before_accum']:
-
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))

-
+    def apply(self, objective): return objective
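The Notes hunk above corrects the SVRG control-variate formula in the docstring. A short plain-PyTorch illustration of that arithmetic (not torchzero API):

```python
# Illustration of the corrected formula from the SVRG docstring above:
# g = g_b(x) - alpha * (g_b(x_0) - g_f(x_0))
# g_b: mini-batch gradient, g_f: full gradient cached at the snapshot point x_0.
import torch

def svrg_gradient(g_b_x: torch.Tensor, g_b_x0: torch.Tensor,
                  g_f_x0: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # the correction term (g_b(x_0) - g_f(x_0)) has zero mean over mini-batches,
    # so the estimate stays unbiased while its variance shrinks near x_0
    return g_b_x - alpha * (g_b_x0 - g_f_x0)
```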
torchzero/modules/weight_decay/__init__.py
CHANGED

@@ -1 +1,2 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .reinit import RandomReinitialize
torchzero/modules/weight_decay/reinit.py
ADDED

@@ -0,0 +1,83 @@
+from functools import partial
+
+import torch
+
+from ...core import Module
+from ...utils import NumberList, TensorList
+
+
+def _reset_except_self(optimizer, var, self: Module):
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()
+
+class RandomReinitialize(Module):
+    """On each step with probability ``p_reinit`` trigger reinitialization,
+    whereby ``p_weights`` weights are reset to their initial values.
+
+    This modifies the parameters directly. Place it as the first module.
+
+    Args:
+        p_reinit (float, optional): probability to trigger reinitialization on each step. Defaults to 0.01.
+        p_weights (float, optional): probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.
+        store_every (int | None, optional): if set, stores new initial values every this many steps. Defaults to None.
+        beta (float, optional):
+            whenever ``store_every`` is triggered, uses linear interpolation with this beta.
+            If ``store_every=1``, this can be set to some value close to 1 such as 0.999
+            to reinitialize to slow parameter EMA. Defaults to 0.
+        reset (bool, optional): whether to reset states of other modules on reinitialization. Defaults to False.
+        seed (int | None, optional): random seed.
+    """
+
+    def __init__(
+        self,
+        p_reinit: float = 0.01,
+        p_weights: float = 0.1,
+        store_every: int | None = None,
+        beta: float = 0,
+        reset: bool = False,
+        seed: int | None = None,
+    ):
+        defaults = dict(p_weights=p_weights, p_reinit=p_reinit, store_every=store_every, beta=beta, reset=reset, seed=seed)
+        super().__init__(defaults)
+
+    def update(self, objective):
+        # this stores initial values to per-parameter states
+        p_init = self.get_state(objective.params, "p_init", init="params", cls=TensorList)
+
+        # store new params every store_every steps
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        store_every = self.defaults["store_every"]
+        if (store_every is not None and step % store_every == 0):
+            beta = self.get_settings(objective.params, "beta", cls=NumberList)
+            p_init.lerp_(objective.params, weight=(1 - beta))
+
+    @torch.no_grad
+    def apply(self, objective):
+        p_reinit = self.defaults["p_reinit"]
+        device = objective.params[0].device
+        generator = self.get_generator(device, self.defaults["seed"])
+
+        # determine whether to trigger reinitialization
+        reinitialize = torch.rand(1, generator=generator, device=device) < p_reinit
+
+        # reinitialize
+        if reinitialize:
+            params = TensorList(objective.params)
+            p_init = self.get_state(params, "p_init", init=params)
+
+
+            # mask with p_weights entries being True
+            p_weights = self.get_settings(params, "p_weights")
+            mask = params.bernoulli_like(p_weights, generator=generator).as_bool()
+
+            # set weights at mask to their initialization
+            params.masked_set_(mask, p_init)
+
+            # reset
+            if self.defaults["reset"]:
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+        return objective
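For orientation, a hypothetical usage sketch of the new `RandomReinitialize` module based on the docstring above; `tz.m.Adam` and `tz.m.LR` are assumed module names and are not part of this diff.

```python
# Hypothetical usage of RandomReinitialize, following its docstring above.
# Only RandomReinitialize and its arguments (p_reinit, p_weights, reset, seed)
# come from this diff; the surrounding modules are assumptions.
import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    # placed first because it modifies the parameters directly (see docstring)
    tz.m.RandomReinitialize(p_reinit=0.01, p_weights=0.1, reset=False, seed=0),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)
```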
torchzero/modules/weight_decay/weight_decay.py
CHANGED

@@ -3,7 +3,7 @@ from typing import Literal

 import torch

-from ...core import Module,
+from ...core import Module, TensorTransform
 from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics


@@ -21,7 +21,7 @@ def weight_decay_(
     return grad_.add_(params.pow(ord-1).copysign_(params).mul_(weight_decay))


-class WeightDecay(
+class WeightDecay(TensorTransform):
     """Weight decay.

     Args:
@@ -63,19 +63,19 @@ class WeightDecay(Transform):
     ```

     """
-    def __init__(self, weight_decay: float, ord: int = 2
+    def __init__(self, weight_decay: float, ord: int = 2):

         defaults = dict(weight_decay=weight_decay, ord=ord)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         weight_decay = NumberList(s['weight_decay'] for s in settings)
         ord = settings[0]['ord']

         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

-class RelativeWeightDecay(
+class RelativeWeightDecay(TensorTransform):
     """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.

     Args:
@@ -117,13 +117,12 @@ class RelativeWeightDecay(Transform):
         ord: int = 2,
         norm_input: Literal["update", "grad", "params"] = "update",
         metric: Metrics = 'mad',
-        target: Target = "update",
     ):
         defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
-        super().__init__(defaults, uses_grad=norm_input == 'grad'
+        super().__init__(defaults, uses_grad=norm_input == 'grad')

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         weight_decay = NumberList(s['weight_decay'] for s in settings)

         ord = settings[0]['ord']
@@ -161,9 +160,9 @@ class DirectWeightDecay(Module):
         super().__init__(defaults)

     @torch.no_grad
-    def
-        weight_decay = self.get_settings(
+    def apply(self, objective):
+        weight_decay = self.get_settings(objective.params, 'weight_decay', cls=NumberList)
         ord = self.defaults['ord']

-        decay_weights_(
-        return
+        decay_weights_(objective.params, weight_decay, ord)
+        return objective
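A quick plain-PyTorch check of the penalty gradient used by `weight_decay_` above; for `ord=2` it reduces to the familiar `weight_decay * params` term:

```python
# sign(p) * |p|^(ord-1) * weight_decay, as computed by weight_decay_ above; for
# ord=2 this is the gradient of the usual (weight_decay / 2) * ||p||^2 penalty.
import torch

p = torch.tensor([-2.0, 0.5, 3.0])
wd, ord = 0.1, 2

penalty_grad = p.pow(ord - 1).copysign(p).mul(wd)
assert torch.allclose(penalty_grad, wd * p)  # ord=2 case: weight_decay * params
print(penalty_grad)  # tensor([-0.2000,  0.0500,  0.3000])
```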
torchzero/modules/wrappers/optim_wrapper.py
CHANGED

@@ -3,41 +3,55 @@ from typing import Any
 import torch

 from ...core.module import Module
-from ...utils import Params, _copy_param_groups, _make_param_groups
+from ...utils.params import Params, _copy_param_groups, _make_param_groups


 class Wrap(Module):
     """
     Wraps a pytorch optimizer to use it as a module.

-
-    Custom param groups are supported only by
+    Note:
+        Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.

     Args:
         opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
-            function that takes in parameters and returns the optimizer, for example
-            or
+            function that takes in parameters and returns the optimizer, for example ``torch.optim.Adam``
+            or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``
         *args:
         **kwargs:
-            Extra args to be passed to opt_fn. The function is called as
+            Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
+        use_param_groups:
+            Whether to pass settings passed to Modular to the wrapped optimizer.

-
-
+            Note that settings to the first parameter are used for all parameters,
+            so if you specified per-parameter settings, they will be ignored.

-
+    ### Example:
+    wrapping pytorch_optimizer.StableAdamW

-
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.Wrap(StableAdamW, lr=1),
-        tz.m.Cautious(),
-        tz.m.LR(1e-2)
-    )
+    ```python

+    from pytorch_optimizer import StableAdamW
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Wrap(StableAdamW, lr=1),
+        tz.m.Cautious(),
+        tz.m.LR(1e-2)
+    )
+    ```

     """
-
-
+
+    def __init__(
+        self,
+        opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
+        *args,
+        use_param_groups: bool = True,
+        **kwargs,
+    ):
+        defaults = dict(use_param_groups=use_param_groups)
+        super().__init__(defaults=defaults)
+
         self._opt_fn = opt_fn
         self._opt_args = args
         self._opt_kwargs = kwargs
@@ -48,12 +62,12 @@ class Wrap(Module):
         self.optimizer = self._opt_fn

     def set_param_groups(self, param_groups):
-        self._custom_param_groups = param_groups
+        self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
         return super().set_param_groups(param_groups)

     @torch.no_grad
-    def
-        params =
+    def apply(self, objective):
+        params = objective.params

         # initialize opt on 1st step
         if self.optimizer is None:
@@ -61,54 +75,47 @@ class Wrap(Module):
             param_groups = params if self._custom_param_groups is None else self._custom_param_groups
             self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)

+        # set optimizer per-parameter settings
+        if self.defaults["use_param_groups"] and objective.modular is not None:
+            for group in self.optimizer.param_groups:
+                first_param = group['params'][0]
+                setting = self.settings[first_param]
+
+                # settings passed in `set_param_groups` are the highest priority
+                # schedulers will override defaults but not settings passed in `set_param_groups`
+                # this is consistent with how Modular does it.
+                if self._custom_param_groups is not None:
+                    setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
+
+                group.update(setting)
+
         # set grad to update
         orig_grad = [p.grad for p in params]
-        for p, u in zip(params,
+        for p, u in zip(params, objective.get_updates()):
             p.grad = u

-        # if this
-
-        # and if there are multiple different per-parameter lrs (would be annoying to support)
-        if var.is_last and (
-            (var.last_module_lrs is None)
-            or
-            (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
-        ):
-            lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
-
-            # update optimizer lr with desired lr
-            if lr != 1:
-                self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
-                for g in self.optimizer.param_groups:
-                    g['__original_lr__'] = g['lr']
-                    g['lr'] = g['lr'] * lr
-
-            # step
+        # if this is last module, simply use optimizer to update parameters
+        if objective.modular is not None and self is objective.modular.modules[-1]:
             self.optimizer.step()

-            # restore original lr
-            if lr != 1:
-                self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
-                for g in self.optimizer.param_groups:
-                    g['lr'] = g.pop('__original_lr__')
-
             # restore grad
             for p, g in zip(params, orig_grad):
                 p.grad = g

-
-            return
+            objective.stop = True; objective.skip_update = True
+            return objective

         # this is not the last module, meaning update is difference in parameters
+        # and passed to next module
         params_before_step = [p.clone() for p in params]
         self.optimizer.step() # step and update params
         for p, g in zip(params, orig_grad):
             p.grad = g
-
+        objective.updates = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
         for p, o in zip(params, params_before_step):
             p.set_(o) # pyright: ignore[reportArgumentType]

-        return
+        return objective

     def reset(self):
         super().reset()
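As a complement to the docstring example above, a hypothetical sketch of the "not the last module" path: the wrapped optimizer's parameter change becomes the update that later modules transform. `torch.optim.SGD` is standard PyTorch; `tz.m.Cautious` and `tz.m.LR` are taken from the docstring example.

```python
# Hypothetical sketch of using Wrap as an inner (non-final) module, per the code above:
# Wrap captures the wrapped optimizer's parameter delta as the update, restores the
# parameters, and hands the delta to the remaining modules.
import torch
import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    tz.m.Wrap(torch.optim.SGD, lr=1.0, momentum=0.9),  # inner optimizer produces the update
    tz.m.Cautious(),                                   # later modules transform that update
    tz.m.LR(1e-2),                                     # final LR scaling before it is applied
)
```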
torchzero/modules/zeroth_order/cd.py
CHANGED

@@ -33,13 +33,16 @@ class CD(Module):
         defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
         super().__init__(defaults)

+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self,
-        closure =
+    def step(self, objective):
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("CD requires closure")

-        params = TensorList(
+        params = TensorList(objective.params)
         ndim = params.global_numel()

         grad_step_size = self.defaults['grad']
@@ -79,7 +82,7 @@ class CD(Module):
         else:
             warnings.warn("CD adaptive=True only works with threepoint=True")

-        f_0 =
+        f_0 = objective.get_loss(False)
         params.flat_set_lambda_(idx, lambda x: x + h)
         f_p = closure(False)

@@ -117,6 +120,6 @@ class CD(Module):
         # ----------------------------- create the update ---------------------------- #
         update = params.zeros_like()
         update.flat_set_(idx, alpha)
-
-        return
+        objective.updates = update
+        return objective

torchzero/optim/root.py
CHANGED

@@ -3,7 +3,7 @@ from collections.abc import Callable

 from abc import abstractmethod
 import torch
-from ..modules.
+from ..modules.second_order.multipoint import sixth_order_3p, sixth_order_5p, two_point_newton, sixth_order_3pm2, _solve

 def make_evaluate(f: Callable[[torch.Tensor], torch.Tensor]):
     def evaluate(x, order) -> tuple[torch.Tensor, ...]:
@@ -53,7 +53,7 @@ class Newton(RootBase):
     def one_iteration(self, x, evaluate): return newton(x, evaluate, self.lstsq)


-class
+class SixthOrder3P(RootBase):
     """sixth-order iterative method

     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
@@ -62,4 +62,4 @@ class SixthOrderP6(RootBase):
     def one_iteration(self, x, evaluate):
         def f(x): return evaluate(x, 0)[0]
         def f_j(x): return evaluate(x, 1)
-        return
+        return sixth_order_3p(x, f, f_j, self.lstsq)
torchzero/optim/utility/split.py
CHANGED

@@ -3,7 +3,8 @@ from collections.abc import Callable, Iterable

 import torch

-from ...utils import flatten
+from ...utils import flatten
+from ...utils.optimizer import get_params

 class Split(torch.optim.Optimizer):
     """Steps will all `optimizers`, also has a check that they have no duplicate parameters.
|