torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
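
Most of the churn in the listing above is the move of the linear-algebra helpers from `torchzero/utils/linalg/` to a top-level `torchzero/linalg/` package (`linear_operator.py`, `qr.py`, `solve.py`, `benchmark.py` are renamed; `eigh.py`, `svd.py`, `orthogonalize.py` and others are new). Code that imported the old path needs a one-line change; a hedged compatibility sketch (module names are taken from the rename entries above, whether a given install exposes both layouts is an assumption):

```python
# import shim for the utils.linalg -> linalg move (0.3.15 -> 0.4.0)
try:
    from torchzero.linalg import linear_operator, qr, solve        # 0.4.0 layout
except ImportError:
    from torchzero.utils.linalg import linear_operator, qr, solve  # 0.3.15 layout
```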

torchzero/modules/conjugate_gradient/cg.py

@@ -3,21 +3,14 @@ from typing import Literal
 
 import torch
 
-from ...core import (
-
-
-    Module,
-    Transform,
-    Var,
-    apply_transform,
-)
-from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
-from ..line_search import LineSearchBase
+from ...core import Chainable, TensorTransform
+
+from ...utils import TensorList, safe_dict_update_, unpack_dicts, unpack_states
 from ..quasi_newton.quasi_newton import HessianUpdateStrategy
 from ..functional import safe_clip
 
 
-class ConguateGradientBase(Transform, ABC):
+class ConguateGradientBase(TensorTransform, ABC):
     """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
 
     This is an abstract class, to use it, subclass it and override `get_beta`.
@@ -52,13 +45,8 @@ class ConguateGradientBase(Transform, ABC):
     """
     def __init__(self, defaults, clip_beta: bool, restart_interval: int | None | Literal['auto'], inner: Chainable | None = None):
         if defaults is None: defaults = {}
-        defaults
-        defaults
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
+        safe_dict_update_(defaults, dict(restart_interval=restart_interval, clip_beta=clip_beta))
+        super().__init__(defaults, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
@@ -74,40 +62,38 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""
 
     @torch.no_grad
-    def
-        tensors =
-        params =
-
-        step = self.global_state.get('step', 0) + 1
-        self.global_state['step'] = step
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        params = TensorList(params)
+        self.increment_counter("step", start=0)
 
         # initialize on first step
-        if self.global_state.get('stage',
+        if self.global_state.get('stage', "first step") == "first update":
             g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
             d_prev.copy_(tensors)
             g_prev.copy_(tensors)
             self.initialize(params, tensors)
-            self.global_state['stage'] =
+            self.global_state['stage'] = "first apply"
 
         else:
             # if `update_tensors` was called multiple times before `apply_tensors`,
             # stage becomes 2
-            self.global_state['stage'] =
+            self.global_state['stage'] = "initialized"
 
     @torch.no_grad
-    def
-        tensors =
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
         step = self.global_state['step']
 
-
-        tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+        assert self.global_state['stage'] != "first update"
 
-
-
-
+        # on 1st apply we don't have previous gradients
+        # so just return tensors
+        if self.global_state['stage'] == "first apply":
+            self.global_state['stage'] = "initialized"
            return tensors
 
-        params =
+        params = TensorList(params)
        g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
 
        # get beta
@@ -119,10 +105,13 @@ class ConguateGradientBase(Transform, ABC):
        dir = tensors.add_(d_prev.mul_(beta))
        d_prev.copy_(dir)
 
-        # resetting
+        # resetting every `reset_interval` steps, use step+1 to not reset on 1st step
+        # so if reset_interval=2, then 1st step collects g_prev and d_prev, then
+        # two steps will happen until reset.
        restart_interval = settings[0]['restart_interval']
        if restart_interval == 'auto': restart_interval = tensors.global_numel() + 1
-
+
+        if restart_interval is not None and (step + 1) % restart_interval == 0:
            self.state.clear()
            self.global_state.clear()
 
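The base class above implements the usual nonlinear conjugate-gradient recurrence: keep the previous gradient and direction, form `d = g + beta * d_prev` (torchzero treats this as an update that the enclosing chain later subtracts), and restart to the plain gradient every `restart_interval` steps (or every `numel + 1` steps with `'auto'`). A standalone sketch of that recurrence with a Polak-Ribière beta, which is only an illustration since the diff leaves `get_beta` abstract:

```python
import torch

def polak_ribiere_beta(g, g_prev):
    # Polak-Ribiere beta, clipped at zero (one common choice; the base class
    # above leaves get_beta abstract, so this particular formula is illustrative)
    beta = torch.dot(g, g - g_prev) / torch.dot(g_prev, g_prev).clamp(min=1e-12)
    return beta.clamp(min=0.0)

def cg_direction(g, g_prev, d_prev, step, restart_interval):
    # periodic restart: fall back to the plain gradient, mirroring the
    # (step + 1) % restart_interval == 0 check in the hunk above
    if d_prev is None or (step + 1) % restart_interval == 0:
        return g.clone()
    # d = g + beta * d_prev; returned as an "update" that gets subtracted
    return g + polak_ribiere_beta(g, g_prev) * d_prev

x = torch.tensor([3.0, -2.0], requires_grad=True)
g_prev = d_prev = None
for step in range(50):
    loss = x[0] ** 2 + x[1] ** 2 + x[0] * x[1]
    g = torch.autograd.grad(loss, x)[0]
    d = cg_direction(g, g_prev, d_prev, step, restart_interval=x.numel() + 1)
    with torch.no_grad():
        x -= 0.1 * d
    g_prev, d_prev = g, d
```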
torchzero/modules/experimental/__init__.py

@@ -1,4 +1,5 @@
 """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
+from .coordinate_momentum import CoordinateMomentum
 from .curveball import CurveBall
 
 # from dct import DCTProjection
@@ -6,14 +7,9 @@ from .fft import FFTProjection
 from .gradmin import GradMin
 from .higher_order_newton import HigherOrderNewton
 from .l_infinity import InfinityNormTrustRegion
-from .momentum import (
-    CoordinateMomentum,
-    NesterovEMASquared,
-    PrecenteredEMASquared,
-    SqrtNesterovEMASquared,
-)
 from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
 from .reduce_outward_lr import ReduceOutwardLR
 from .scipy_newton_cg import ScipyNewtonCG
+from .spsa1 import SPSA1
 from .structural_projections import BlockPartition, TensorizeProjection
torchzero/modules/experimental/coordinate_momentum.py (new file)

@@ -0,0 +1,36 @@
+import torch
+
+from ...core import TensorTransform
+from ...utils import NumberList, TensorList, unpack_states
+
+
+def coordinate_momentum_(
+    tensors: TensorList,
+    velocity_: TensorList,
+    p: float | NumberList,
+):
+    """
+    sets `velocity_` to p% random values from `tensors`.
+
+    Returns `velocity_`
+    """
+    mask = tensors.bernoulli_like(p).as_bool()
+    velocity_.masked_set_(mask, tensors)
+    return velocity_
+
+
+class CoordinateMomentum(TensorTransform):
+    """Maintains a momentum buffer, on each step each value in the buffer has ``p`` chance to be updated with the new value.
+
+    Args:
+        p (float, optional): _description_. Defaults to 0.1.
+    """
+    def __init__(self, p: float = 0.1):
+        defaults = dict(p=p)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        p = NumberList(s['p'] for s in settings)
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
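`CoordinateMomentum` keeps a velocity buffer and, on every step, each entry of the buffer is replaced by the incoming value with probability `p`, otherwise the old entry is kept. A minimal sketch of the same per-coordinate masked update on a plain tensor (assuming `bernoulli_like` / `masked_set_` from TensorList do exactly this):

```python
import torch

def coordinate_momentum_step(update: torch.Tensor, velocity: torch.Tensor, p: float) -> torch.Tensor:
    # each coordinate of the buffer is refreshed with probability p,
    # otherwise the previous velocity value is kept
    mask = torch.bernoulli(torch.full_like(update, p)).bool()
    velocity[mask] = update[mask]
    return velocity.clone()  # clone before handing the buffer downstream, as the module does

velocity = torch.zeros(5)
for _ in range(3):
    update = torch.randn(5)  # stand-in for the incoming gradient/update
    out = coordinate_momentum_step(update, velocity, p=0.1)
```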
torchzero/modules/experimental/curveball.py

@@ -1,25 +1,25 @@
 from typing import Literal
-
+
 import torch
 
-from ...core import
-from ...utils import NumberList, TensorList,
-
+from ...core import Chainable, Transform, step, HVPMethod
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
 
 def curveball(
     tensors: TensorList,
     z_: TensorList,
-
+    Hzz: TensorList,
     momentum: float | NumberList,
     precond_lr: float | NumberList,
 ):
     """returns z_, clone it!!! (no just negate it)"""
-    delta =
+    delta = Hzz + tensors
     z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
     return z_
 
 
-class CurveBall(
+class CurveBall(Transform):
     """CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.
 
     For now this implementation does not include automatic ρ, α and β hyper-parameters in closed form, therefore it is expected to underperform compared to official implementation (https://github.com/jotaf98/pytorch-curveball/tree/master) so I moved this to experimental.
@@ -36,7 +36,7 @@ class CurveBall(Module):
         self,
         precond_lr: float=1e-3,
         momentum: float=0.9,
-        hvp_method:
+        hvp_method: HVPMethod = "autograd",
         h: float = 1e-3,
         reg: float = 1,
         inner: Chainable | None = None,
@@ -44,46 +44,30 @@ class CurveBall(Module):
         defaults = dict(precond_lr=precond_lr, momentum=momentum, hvp_method=hvp_method, h=h, reg=reg)
         super().__init__(defaults)
 
-
+        self.set_child('inner', inner)
 
     @torch.no_grad
-    def
-
-
-
-
-        h = settings['h']
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
-        precond_lr, momentum, reg =
+        precond_lr, momentum, reg = unpack_dicts(settings, 'precond_lr', 'momentum', 'reg', cls=NumberList)
 
-
-        closure = var.closure
+        closure = objective.closure
         assert closure is not None
 
-        z, Hz =
-
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-            Hvp = hvp(params, grad, z)
-
-        elif hvp_method == 'forward':
-            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=var.get_grad(), normalize=True)
-
-        elif hvp_method == 'central':
-            loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
-
-        else:
-            raise ValueError(hvp_method)
-
-
-        Hz.set_(Hvp + z*reg)
+        z, Hz = unpack_states(states, params, 'z', 'Hz', cls=TensorList)
+        Hz, _ = objective.hessian_vector_product(z, rgrad=None, at_x0=True, hvp_method=hvp_method, h=h)
 
+        Hz = TensorList(Hz)
+        Hzz = Hz.add_(z * reg)
 
-
-
-        update = apply_transform(self.children['inner'], update, params, grads=var.grad, var=var)
+        objective = self.inner_step("inner", objective, must_exist=False)
+        updates = objective.get_updates()
 
-        z = curveball(TensorList(
-
+        z = curveball(TensorList(updates), z, Hzz, momentum=momentum, precond_lr=precond_lr)
+        objective.updates = z.neg()
 
-        return
+        return objective
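The rewritten `apply_states` builds the CurveBall recurrence z ← ρ z − β (H z + λ z + g) out of `hessian_vector_product` plus the `reg` term, and returns −z as the update (updates are subtracted downstream, so the actual parameter step is +z). A self-contained sketch on a small quadratic with a double-backward Hessian-vector product; ρ, β, λ below are rough stand-ins for `momentum`, `precond_lr` and `reg`:

```python
import torch

# toy quadratic: f(x) = 0.5 x^T A x - b^T x
A = torch.tensor([[3.0, 0.5], [0.5, 1.0]])
b = torch.tensor([1.0, -2.0])
x = torch.zeros(2, requires_grad=True)

z = torch.zeros(2)               # CurveBall state
rho, beta, lam = 0.9, 0.1, 1.0   # stand-ins for momentum, precond_lr, reg

for _ in range(200):
    loss = 0.5 * x @ A @ x - b @ x
    g = torch.autograd.grad(loss, x, create_graph=True)[0]
    Hz = torch.autograd.grad(g, x, grad_outputs=z)[0]  # Hessian-vector product H z
    with torch.no_grad():
        delta = Hz + lam * z + g     # Hzz + g, as in the rewritten apply_states
        z = rho * z - beta * delta   # z <- rho * z - beta * delta
        x += z                       # module returns -z, and updates are subtracted
```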
torchzero/modules/experimental/gradmin.py

@@ -5,7 +5,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module,
+from ...core import Module, Objective, Chainable
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
@@ -43,7 +43,7 @@ class GradMin(Reformulation):
         super().__init__(defaults, modules=modules)
 
     @torch.no_grad
-    def closure(self, backward, closure, params,
+    def closure(self, backward, closure, params, objective):
         settings = self.settings[params[0]]
         loss_term = settings['loss_term']
         relative = settings['relative']
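`GradMin` only picks up the new `objective`-based closure signature here, and the diff does not show what its reformulated loss actually is. Purely as an illustration of the double-backward pattern such a gradient-based reformulation relies on (an assumption, not the module's confirmed formula), here is a standalone objective that minimizes the squared gradient norm plus a weighted copy of the original loss:

```python
import torch

def grad_norm_objective(f, x, loss_term=0.1):
    # illustrative reformulation: squared gradient norm plus a weighted copy of
    # the original loss (GradMin's exact formula is not shown in this diff)
    loss = f(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    return g.pow(2).sum() + loss_term * loss

f = lambda v: (v[0] - 1) ** 2 + 10 * (v[1] - v[0] ** 2) ** 2
x = torch.tensor([2.0, -1.0], requires_grad=True)
opt = torch.optim.SGD([x], lr=1e-3)

for _ in range(100):
    opt.zero_grad()
    grad_norm_objective(f, x).backward()  # differentiates through the inner gradient
    opt.step()
```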
torchzero/modules/experimental/higher_order_newton.py

@@ -1,21 +1,12 @@
-import itertools
 import math
-import warnings
-from collections.abc import Callable
-from contextlib import nullcontext
-from functools import partial
 from typing import Any, Literal
 
 import numpy as np
 import scipy.optimize
 import torch
 
-from ...core import
+from ...core import DerivativesMethod, Module
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    flatten_jacobian,
-    jacobian_wrt,
-)
 
 _LETTERS = 'abcdefghijklmnopqrstuvwxyz'
 def _poly_eval(s: np.ndarray, c, derivatives):
@@ -195,22 +186,22 @@ class HigherOrderNewton(Module):
         max_attempts = 10,
         boundary_tol: float = 1e-2,
         de_iters: int | None = None,
-
+        derivatives_method: DerivativesMethod = "batched_autograd",
     ):
         if init is None:
             if trust_method == 'bounds': init = 1
             else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init,
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad, derivatives_method=derivatives_method)
         super().__init__(defaults)
 
     @torch.no_grad
-    def
-        params = TensorList(
-        closure =
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
 
-        settings = self.
+        settings = self.defaults
         order = settings['order']
         nplus = settings['nplus']
         nminus = settings['nminus']
@@ -219,31 +210,12 @@ class HigherOrderNewton(Module):
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
         max_attempts = settings['max_attempts']
-        vectorize = settings['vectorize']
         boundary_tol = settings['boundary_tol']
         rho_good = settings['rho_good']
         rho_bad = settings['rho_bad']
 
         # ------------------------ calculate grad and hessian ------------------------ #
-
-        loss = var.loss = var.loss_approx = closure(False)
-
-        g_list = torch.autograd.grad(loss, params, create_graph=True)
-        var.grad = list(g_list)
-
-        g = torch.cat([t.ravel() for t in g_list])
-        n = g.numel()
-        derivatives = [g]
-        T = g # current derivatives tensor
-
-        # get all derivative up to order
-        for o in range(2, order + 1):
-            is_last = o == order
-            T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-            with torch.no_grad() if is_last else nullcontext():
-                # the shape is (ndim, ) * order
-                T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                derivatives.append(T)
+        loss, *derivatives = objective.derivatives(order=order, at_x0=True, method=self.defaults["derivatives_method"])
 
         x0 = torch.cat([p.ravel() for p in params])
 
@@ -301,7 +273,8 @@ class HigherOrderNewton(Module):
         vec_to_tensors_(x0, params)
         reduction = loss - loss_star
 
-        rho = reduction / (max(pred_reduction,
+        rho = reduction / (max(pred_reduction, finfo.tiny * 2)) # pyright:ignore[reportArgumentType]
+
         # failed step
         if rho < rho_bad:
             self.global_state['trust_region'] = trust_value * nminus
@@ -320,8 +293,9 @@ class HigherOrderNewton(Module):
         assert x_star is not None
         if success:
             difference = vec_to_tensors(x0 - x_star, params)
-
+            objective.updates = list(difference)
         else:
-
-
+            objective.updates = params.zeros_like()
+
+        return objective
 
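The removed block computed the gradient and then repeatedly differentiated the previous derivative tensor with `jacobian_wrt`/`flatten_jacobian` to get all derivatives up to `order`; 0.4.0 folds this into `objective.derivatives(order=...)`. A standalone, unbatched sketch of the same idea for a scalar function of one flat parameter tensor (shapes come out as (n,), (n, n), (n, n, n), ...):

```python
import torch

def derivatives_up_to(f, x, order=3):
    # returns (loss, [grad, Hessian, 3rd derivative, ...]) for scalar f at x by
    # differentiating each entry of the previous derivative tensor (unbatched,
    # so one autograd.grad call per entry; jacobian_wrt batched this step)
    n = x.numel()
    loss = f(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    derivs = [g]
    T = g
    for o in range(2, order + 1):
        is_last = o == order
        rows = []
        for t in T.reshape(-1):
            (row,) = torch.autograd.grad(t, x, create_graph=not is_last, retain_graph=True)
            rows.append(row)
        T = torch.stack(rows).reshape(*T.shape, n)
        derivs.append(T)
    return loss.detach(), derivs

x = torch.tensor([1.0, 2.0], requires_grad=True)
loss, (g, H, T3) = derivatives_up_to(lambda v: v[0] ** 3 * v[1] + v[1] ** 4, x, order=3)
```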
torchzero/modules/experimental/newton_solver.py

@@ -1,11 +1,10 @@
-from collections.abc import Callable
-from typing import Any
+from collections.abc import Callable
+from typing import Any
 
 import torch
 
-from ...core import Chainable, Modular, Module,
-from ...utils import TensorList
-from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+from ...core import Chainable, Modular, Module, step, HVPMethod
+from ...utils import TensorList
 from ..quasi_newton import LBFGS
 
 
@@ -19,24 +18,26 @@ class NewtonSolver(Module):
         tol:float | None=1e-3,
         reg: float = 0,
         warm_start=True,
-        hvp_method:
+        hvp_method: HVPMethod = "autograd",
         reset_solver: bool = False,
         h: float= 1e-3,
+
         inner: Chainable | None = None,
     ):
-        defaults =
-
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults)
 
-
-        self.set_child('inner', inner)
+        self.set_child("inner", inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
     @torch.no_grad
-    def
-
-
+    def apply(self, objective):
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
@@ -44,51 +45,19 @@ class NewtonSolver(Module):
         maxiter = settings['maxiter']
         maxiter1 = settings['maxiter1']
         tol = settings['tol']
-        reg = settings['reg']
         hvp_method = settings['hvp_method']
         warm_start = settings['warm_start']
         h = settings['h']
         reset_solver = settings['reset_solver']
 
         self._num_hvps_last_step = 0
-        # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                self._num_hvps_last_step += 1
-                with torch.enable_grad():
-                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
-                if reg != 0: Hvp = Hvp + (x*reg)
-                return Hvp
-
-        else:
-
-            with torch.enable_grad():
-                grad = var.get_grad()
-
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            else:
-                raise ValueError(hvp_method)
 
+        # ---------------------- Hessian vector product function --------------------- #
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
 
         # -------------------------------- inner step -------------------------------- #
-
-
-        b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
@@ -112,7 +81,7 @@ class NewtonSolver(Module):
         solver = self.global_state['solver']
 
         def lstsq_closure(backward=True):
-            Hx =
+            Hx = H_mv(x).detach()
             # loss = (Hx-b).pow(2).global_mean()
            # if backward:
            #     solver.zero_grad()
@@ -122,7 +91,7 @@ class NewtonSolver(Module):
            loss = residual.pow(2).global_mean()
            if backward:
                with torch.no_grad():
-                    H_residual =
+                    H_residual = H_mv(residual)
                    n = residual.global_numel()
                    x.set_grad_((2.0 / n) * H_residual)
 
@@ -143,8 +112,8 @@ class NewtonSolver(Module):
            assert x0 is not None
            x0.copy_(x)
 
-
+        objective.updates = x.detach()
        self._num_hvps += self._num_hvps_last_step
-        return
+        return objective
 
 
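`NewtonSolver` never materializes the Hessian: it only needs a Hessian-vector product (now taken from `objective.list_Hvp_function`) and hands it to an iterative solver for H x = b, where b is the (possibly inner-transformed) update. A compact standalone version of that idea, using a double-backward Hvp and plain conjugate gradient in place of the module's least-squares solver:

```python
import torch

def matrix_free_newton_step(loss_fn, x, cg_iters=20, tol=1e-10):
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)

    def H_mv(v):
        # Hessian-vector product via double backward; keep the graph for repeated calls
        (hv,) = torch.autograd.grad(g, x, grad_outputs=v, retain_graph=True)
        return hv

    # plain conjugate gradient on H s = g (stand-in for the module's lstsq solver)
    b = g.detach()
    s = torch.zeros_like(b)
    r = b.clone()
    p = r.clone()
    rs = r @ r
    for _ in range(cg_iters):
        Hp = H_mv(p)
        alpha = rs / (p @ Hp)
        s = s + alpha * p
        r = r - alpha * Hp
        rs_new = r @ r
        if rs_new < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return s  # Newton update, to be subtracted: x_new = x - s

x = torch.tensor([1.5, -0.5], requires_grad=True)
step = matrix_free_newton_step(lambda v: 2 * v[0] ** 2 + v[0] * v[1] + v[1] ** 2, x)
with torch.no_grad():
    x -= step
```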
torchzero/modules/experimental/newtonnewton.py

@@ -7,7 +7,8 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, step
+from ...linalg.linear_operator import Dense
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     flatten_jacobian,
@@ -19,7 +20,7 @@ from ..second_order.newton import (
     _least_squares_solve,
     _lu_solve,
 )
-
+
 
 class NewtonNewton(Module):
     """Applies Newton-like preconditioning to Newton step.
@@ -51,9 +52,10 @@ class NewtonNewton(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def update(self,
-
-
+    def update(self, objective):
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonNewton requires closure')
 
         settings = self.settings[params[0]]
@@ -66,9 +68,9 @@ class NewtonNewton(Module):
         # ------------------------ calculate grad and hessian ------------------------ #
         Hs = []
         with torch.enable_grad():
-            loss =
+            loss = objective.loss = objective.loss_approx = closure(False)
             g_list = torch.autograd.grad(loss, params, create_graph=True)
-
+            objective.grads = list(g_list)
 
             xp = torch.cat([t.ravel() for t in g_list])
             I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
@@ -93,13 +95,14 @@ class NewtonNewton(Module):
         self.global_state['xp'] = xp.nan_to_num_(0,0,0)
 
     @torch.no_grad
-    def apply(self,
-        params =
+    def apply(self, objective):
+        params = objective.params
         xp = self.global_state['xp']
-
-        return
+        objective.updates = vec_to_tensors(xp, params)
+        return objective
 
-
+    @torch.no_grad
+    def get_H(self, objective=...):
        Hs = self.global_state["Hs"]
        if len(Hs) == 1: return Dense(Hs[0])
        return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
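The `update` hunk shows the usual dense-Hessian setup: take the gradient with `create_graph=True`, flatten it, and build an identity matrix whose rows are pushed back through autograd to produce Hessian rows (the stored `Hs` are then composed with `torch.linalg.multi_dot` in `get_H`). A minimal standalone version of that row-by-row pattern, assuming a single flat parameter tensor and ignoring the module's repeated-preconditioning logic:

```python
import torch

def dense_hessian(loss_fn, x):
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    xp = g.reshape(-1)                                   # flattened gradient
    I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
    rows = []
    for e in I:                                          # one backward pass per row of I
        (row,) = torch.autograd.grad(xp, x, grad_outputs=e, retain_graph=True)
        rows.append(row.reshape(-1))
    return torch.stack(rows)                             # (n, n) Hessian

x = torch.tensor([1.0, 2.0], requires_grad=True)
H = dense_hessian(lambda v: v[0] ** 2 * v[1] + v[1] ** 3, x)
```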
torchzero/modules/experimental/reduce_outward_lr.py

@@ -1,28 +1,28 @@
 import torch
 
-from ...core import
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_states, unpack_dicts
 
-class ReduceOutwardLR(
+class ReduceOutwardLR(TensorTransform):
     """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
 
-
+    Warning:
        This sounded good but after testing turns out it sucks.
    """
-    def __init__(self, mul = 0.5, use_grad=False, invert=False
+    def __init__(self, mul = 0.5, use_grad=False, invert=False):
        defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
-        super().__init__(defaults, uses_grad=use_grad
+        super().__init__(defaults, uses_grad=use_grad)
 
    @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        params = TensorList(params)
        tensors = TensorList(tensors)
 
        mul = [s['mul'] for s in settings]
        s = settings[0]
-        use_grad =
+        use_grad = self._uses_grad
        invert = s['invert']
 
        if use_grad: cur = grads