torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between the two published package versions as they appear in the public registry.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
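Two structural changes dominate the file list: the `Var` carrier object (`torchzero/core/var.py`, `tests/test_vars.py`, both removed) is replaced by `Objective` (`torchzero/core/objective.py`, `tests/test_objective.py`, both added), and the linear-algebra helpers move from `torchzero/utils/linalg` to a top-level `torchzero/linalg` package. For orientation, here is a minimal sketch, not taken from this diff, of what a custom module might look like against the new `Objective` API. Only `objective.updates`, the `apply` hook, and `Module.__init__(defaults)` actually appear in the reformulation diff below; the module name, its logic, and the assumption that `updates` is a list of tensors are hypothetical:

    import torch
    from torchzero.core.module import Module

    class SignUpdates(Module):
        """Hypothetical example: replace each pending update with its sign."""
        def __init__(self):
            super().__init__(defaults=None)

        def apply(self, objective):
            # `updates` may be unset if no earlier module produced one
            if objective.updates is not None:
                objective.updates = [u.sign() for u in objective.updates]
            return objective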
torchzero/core/reformulation.py
CHANGED

@@ -3,12 +3,20 @@ from collections.abc import Callable, Sequence
 
 import torch
 
-from .chain import Chain
 from .module import Chainable, Module
-from .var import Var
+from .objective import Objective
 
 
 class Reformulation(Module, ABC):
+    """Reformulation allows the definition of a new closure which returns a custom loss and gradient.
+
+    If ``modules`` are passed, steps with those modules using the reformulated closure. Only the ``step`` method is supported.
+
+    If ``modules`` is ``None``, sets the new closure on the objective so that all further modules use it.
+    In that case make sure this module is first.
+
+    To use this, subclass and override ``closure`` and optionally ``pre_step``.
+    """
     def __init__(self, defaults: dict | None, modules: Chainable | None):
         super().__init__(defaults)
 
@@ -16,30 +24,52 @@ class Reformulation(Module, ABC):
         self.set_child("modules", modules)
 
     @abstractmethod
-    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor],
+    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], objective: Objective) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
         """
-        returns (loss, gradient)
+        returns ``(loss, gradient)``; if ``backward`` is False, the gradient can be None.
 
-        If evaluating original loss/gradient at
+        If evaluating the original loss/gradient at ``x0``, set them to ``objective``.
         """
 
-    def pre_step(self,
-        """This runs once before each step, whereas
+    def pre_step(self, objective: Objective):
+        """This runs once before each step, whereas ``closure`` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
 
-    def
-
-
+    def update(self, objective):
+        if "modules" in self.children:
+            raise RuntimeError(f"Reformulation ({self.__class__.__name__}) only supports the `step` method if it has sub-modules.")
+
+        self.pre_step(objective) # pylint:disable = assignment-from-no-return
+
+        if objective.closure is None: raise RuntimeError("Reformulation requires closure")
+        params, closure = objective.params, objective.closure # make sure to decouple from `objective` object
+
+        # define modified closure and set objective to use it
+        def modified_closure(backward=True):
+            loss, grad = self.closure(backward, closure, params, objective)
 
-
-
+            if grad is not None:
+                for p,g in zip(params, grad):
+                    p.grad = g
+
+            return loss
+
+        objective.closure = modified_closure
+
+    def apply(self, objective): return objective
+
+    def step(self, objective):
 
-        # step with children
         if 'modules' in self.children:
 
+            self.pre_step(objective) # pylint:disable = assignment-from-no-return
+
+            if objective.closure is None: raise RuntimeError("Reformulation requires closure")
+            params, closure = objective.params, objective.closure # make sure to decouple from `objective` object
+
             # make a reformulated closure
             def modified_closure(backward=True):
-                loss, grad = self.closure(backward, closure, params,
+                loss, grad = self.closure(backward, closure, params, objective)
 
                 if grad is not None:
                     for p,g in zip(params, grad):
@@ -47,21 +77,22 @@ class Reformulation(Module, ABC):
 
                 return loss
 
-            # set it to a new
-
-
+            # set it to a new Objective object
+            modified_objective = objective.clone(clone_updates=False)
+            modified_objective.closure = modified_closure
 
-            #
+            # update the child
            modules = self.children['modules']
-
+            modified_objective = modules.step(modified_objective)
 
             # modified_var.loss and grad refers to loss and grad of a modified objective
             # so we only take the update
-
+            objective.updates = modified_objective.updates
 
-        # or just
+        # or just set the closure to the modified one
+        # (update already calls self.pre_step)
         else:
-
-
+            self.update(objective)
+            self.apply(objective) # does nothing unless overridden
 
-        return
+        return objective