torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
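The renames above move the linear-algebra helpers out of `torchzero.utils.linalg` into a new top-level `torchzero.linalg` package (and split the monolithic `torchzero/optim/wrappers/scipy.py` into a `scipy/` subpackage). A minimal import-migration sketch, assuming the moved modules are importable directly from the new package — only the file renames themselves are visible in this diff:

```python
# torchzero 0.3.14 layout (removed in 0.4.0):
# from torchzero.utils.linalg import qr, solve, linear_operator

# torchzero 0.4.0 layout, following the renamed files listed above:
from torchzero.linalg import qr, solve, linear_operator
```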
torchzero/modules/misc/homotopy.py CHANGED

@@ -13,27 +13,27 @@ class HomotopyBase(Module):
         """transform the loss"""

     @torch.no_grad
-    def
-    if
-
+    def apply(self, objective):
+        if objective.loss is not None:
+            objective.loss = self.loss_transform(objective.loss)

-        closure =
+        closure = objective.closure
         if closure is None: raise RuntimeError("SquareHomotopy requires closure")

         def homotopy_closure(backward=True):
             if backward:
                 with torch.enable_grad():
                     loss = self.loss_transform(closure(False))
-                    grad = torch.autograd.grad(loss,
-                for p,g in zip(
+                    grad = torch.autograd.grad(loss, objective.params, allow_unused=True)
+                for p,g in zip(objective.params, grad):
                     p.grad = g
             else:
                 loss = self.loss_transform(closure(False))

             return loss

-
-        return
+        objective.closure = homotopy_closure
+        return objective

 class SquareHomotopy(HomotopyBase):
     def __init__(self): super().__init__()
@@ -57,3 +57,11 @@ class LambdaHomotopy(HomotopyBase):
         super().__init__(defaults)

     def loss_transform(self, loss): return self.defaults['fn'](loss)
+
+class FixedLossHomotopy(HomotopyBase):
+    def __init__(self, value: float = 1):
+        defaults = dict(value=value)
+        super().__init__(defaults)
+
+    def loss_transform(self, loss): return loss / loss.detach().clip(min=torch.finfo(loss.dtype).tiny * 2)
+
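For context on the hunk above: the new `FixedLossHomotopy` divides the loss by its own detached value (clipped away from zero), so the transformed loss evaluates to roughly 1 and its gradient is the original gradient scaled by `1 / loss`. A standalone sketch of that effect in plain PyTorch, not using the torchzero module itself:

```python
import torch

def fixed_loss_transform(loss: torch.Tensor) -> torch.Tensor:
    # same expression as FixedLossHomotopy.loss_transform in the hunk above
    return loss / loss.detach().clip(min=torch.finfo(loss.dtype).tiny * 2)

x = torch.tensor([3.0], requires_grad=True)
loss = (x ** 2).sum()          # loss = 9, d(loss)/dx = 6

fixed_loss_transform(loss).backward()
print(x.grad)                  # tensor([0.6667]) == 6 / 9, i.e. grad / loss
```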
torchzero/modules/misc/misc.py CHANGED

@@ -6,7 +6,7 @@ from typing import Literal

 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, TensorTransform, Transform, Objective
 from ...utils import (
     Distributions,
     Metrics,
@@ -19,15 +19,15 @@ from ...utils import (
 )


-class Previous(
+class Previous(TensorTransform):
     """Maintains an update from n steps back, for example if n=1, returns previous update"""
-    def __init__(self, n=1
+    def __init__(self, n=1):
         defaults = dict(n=n)
-        super().__init__(
+        super().__init__(defaults=defaults)


     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         n = setting['n']

         if 'history' not in state:
@@ -38,13 +38,13 @@ class Previous(TensorwiseTransform):
         return state['history'][0]


-class LastDifference(
+class LastDifference(TensorTransform):
     """Outputs difference between past two updates."""
-    def __init__(self,
-        super().__init__(
+    def __init__(self,):
+        super().__init__()

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
         difference = torch._foreach_sub(tensors, prev_tensors)
         for p, c in zip(prev_tensors, tensors): p.set_(c)
@@ -53,16 +53,16 @@ class LastDifference(Transform):
 class LastGradDifference(Module):
     """Outputs difference between past two gradients."""
     def __init__(self):
-        super().__init__(
+        super().__init__()

     @torch.no_grad
-    def
-        grad =
-        prev_grad = self.get_state(
+    def apply(self, objective):
+        grad = objective.get_grads()
+        prev_grad = self.get_state(objective.params, 'prev_grad') # initialized to 0
         difference = torch._foreach_sub(grad, prev_grad)
         for p, c in zip(prev_grad, grad): p.copy_(c)
-
-        return
+        objective.updates = list(difference)
+        return objective

 class LastParamDifference(Module):
     """Outputs difference between past two parameters, which is the effective previous update."""
@@ -70,36 +70,36 @@ class LastParamDifference(Module):
         super().__init__({})

     @torch.no_grad
-    def
-        params =
-        prev_params = self.get_state(
+    def apply(self, objective):
+        params = objective.params
+        prev_params = self.get_state(objective.params, 'prev_params') # initialized to 0
         difference = torch._foreach_sub(params, prev_params)
         for p, c in zip(prev_params, params): p.copy_(c)
-
-        return
+        objective.updates = list(difference)
+        return objective



-class LastProduct(
+class LastProduct(TensorTransform):
     """Outputs difference between past two updates."""
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__()

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
         prod = torch._foreach_mul(tensors, prev)
         for p, c in zip(prev, tensors): p.set_(c)
         return prod

-class LastRatio(
-    """Outputs ratio between past two updates, the numerator is determined by
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur'
+class LastRatio(TensorTransform):
+    """Outputs ratio between past two updates, the numerator is determined by ``numerator`` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur'):
         defaults = dict(numerator=numerator)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
         numerator = settings[0]['numerator']
         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
@@ -107,14 +107,14 @@ class LastRatio(Transform):
         for p, c in zip(prev, tensors): p.set_(c)
         return ratio

-class LastAbsoluteRatio(
-    """Outputs ratio between absolute values of past two updates the numerator is determined by
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8
+class LastAbsoluteRatio(TensorTransform):
+    """Outputs ratio between absolute values of past two updates the numerator is determined by ``numerator`` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8):
         defaults = dict(numerator=numerator, eps=eps)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
         numerator = settings[0]['numerator']
         eps = NumberList(s['eps'] for s in settings)
@@ -127,139 +127,139 @@ class LastAbsoluteRatio(Transform):
         for p, c in zip(prev, tensors): p.set_(c)
         return ratio

-class GradSign(
+class GradSign(TensorTransform):
     """Copies gradient sign to update."""
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__(uses_grad=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [t.copysign_(g) for t,g in zip(tensors, grads)]

-class UpdateSign(
+class UpdateSign(TensorTransform):
     """Outputs gradient with sign copied from the update."""
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__(uses_grad=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place

-class GraftToGrad(
+class GraftToGrad(TensorTransform):
     """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
-    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)

-class GraftGradToUpdate(
+class GraftGradToUpdate(TensorTransform):
     """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
-    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)


-class GraftToParams(
-    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than
-    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4
+class GraftToParams(TensorTransform):
+    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than ``eps``."""
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)

-class Relative(
-    """Multiplies update by absolute parameter values to make it relative to their magnitude,
-    def __init__(self, min_value:float = 1e-4
+class Relative(TensorTransform):
+    """Multiplies update by absolute parameter values to make it relative to their magnitude, ``min_value`` is minimum allowed value to avoid getting stuck at 0."""
+    def __init__(self, min_value:float = 1e-4):
         defaults = dict(min_value=min_value)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
         torch._foreach_mul_(tensors, mul)
         return tensors

 class FillLoss(Module):
-    """Outputs tensors filled with loss value times
+    """Outputs tensors filled with loss value times ``alpha``"""
     def __init__(self, alpha: float = 1, backward: bool = True):
         defaults = dict(alpha=alpha, backward=backward)
         super().__init__(defaults)

     @torch.no_grad
-    def
-        alpha = self.get_settings(
-        loss =
-
-        return
-
-class MulByLoss(
-    """Multiplies update by loss times
-    def __init__(self, alpha: float = 1, min_value:float = 1e-
+    def apply(self, objective):
+        alpha = self.get_settings(objective.params, 'alpha')
+        loss = objective.get_loss(backward=self.defaults['backward'])
+        objective.updates = [torch.full_like(p, loss*a) for p,a in zip(objective.params, alpha)]
+        return objective
+
+class MulByLoss(TensorTransform):
+    """Multiplies update by loss times ``alpha``"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-16, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
+        super().__init__(defaults, uses_loss=True)

     @torch.no_grad
-    def
-
-
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        alpha, min_value = unpack_dicts(settings, 'alpha', 'min_value')
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_mul_(
-        return
+        torch._foreach_mul_(tensors, mul)
+        return tensors

-class DivByLoss(
-    """Divides update by loss times
-    def __init__(self, alpha: float = 1, min_value:float = 1e-
+class DivByLoss(TensorTransform):
+    """Divides update by loss times ``alpha``"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-16, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
+        super().__init__(defaults, uses_loss=True)

     @torch.no_grad
-    def
-
-
-
-        torch._foreach_div_(
-        return
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        alpha, min_value = unpack_dicts(settings, 'alpha', 'min_value')
+        denom = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_div_(tensors, denom)
+        return tensors


-class NoiseSign(
+class NoiseSign(TensorTransform):
     """Outputs random tensors with sign copied from the update."""
     def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
         defaults = dict(distribution=distribution, variance=variance)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         variance = unpack_dicts(settings, 'variance')
         return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)

-class HpuEstimate(
+class HpuEstimate(TensorTransform):
     """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
     def __init__(self):
         defaults = dict()
-        super().__init__(defaults
+        super().__init__(defaults)

     def reset_for_online(self):
         super().reset_for_online()
         self.clear_state_keys('prev_params', 'prev_update')

     @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
         s = torch._foreach_sub(params, prev_params)
         y = torch._foreach_sub(tensors, prev_update)
@@ -269,50 +269,48 @@ class HpuEstimate(Transform):
         self.store(params, 'y', y)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return [self.state[p]['y'] for p in params]

 class RandomHvp(Module):
-    """Returns a hessian-vector product with a random vector"""
+    """Returns a hessian-vector product with a random vector, optionally times vector"""

     def __init__(
         self,
         n_samples: int = 1,
         distribution: Distributions = "normal",
         update_freq: int = 1,
-
+        zHz: bool = False,
+        hvp_method: Literal["autograd", "fd_forward", "central"] = "autograd",
         h=1e-3,
+        seed: int | None = None
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self']
         super().__init__(defaults)

     @torch.no_grad
-    def
-        params = TensorList(
-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-        distribution = settings['distribution']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-        update_freq = settings['update_freq']
+    def apply(self, objective):
+        params = TensorList(objective.params)

         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

         D = None
+        update_freq = self.defaults['update_freq']
         if step % update_freq == 0:

-
-
-
-
-
-
-
-
-
+            D, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = self.defaults['n_samples'],
+                distribution = self.defaults['distribution'],
+                hvp_method = self.defaults['hvp_method'],
+                h = self.defaults['h'],
+                zHz = self.defaults["zHz"],
+                generator = self.get_generator(params[0].device, self.defaults["seed"]),
+            )

-        if n_samples > 1: torch._foreach_div_(D, n_samples)
         if update_freq != 1:
             assert D is not None
             D_buf = self.get_state(params, "D", cls=TensorList)
@@ -321,8 +319,8 @@ class RandomHvp(Module):
         if D is None:
             D = self.get_state(params, "D", cls=TensorList)

-
-        return
+        objective.updates = list(D)
+        return objective

 @torch.no_grad
 def _load_best_parameters(params: Sequence[torch.Tensor], best_params: Sequence[torch.Tensor]):
@@ -370,14 +368,14 @@ class SaveBest(Module):
         super().__init__()

     @torch.no_grad
-    def
-        loss = tofloat(
+    def apply(self, objective):
+        loss = tofloat(objective.get_loss(False))
         lowest_loss = self.global_state.get('lowest_loss', float("inf"))

         if loss < lowest_loss:
             self.global_state['lowest_loss'] = loss
-            best_params =
-
-
+            best_params = objective.attrs['best_params'] = [p.clone() for p in objective.params]
+            objective.attrs['best_loss'] = loss
+            objective.attrs['load_best_params'] = partial(_load_best_parameters, params=objective.params, best_params=best_params)

-        return
+        return objective