torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
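The rename entries above move the linear-algebra helpers out of `torchzero.utils.linalg` into a new top-level `torchzero.linalg` package. A sketch of the corresponding import change, with submodule names taken from the renamed files (whether `torchzero.linalg` also re-exports their contents is not shown in this diff):

```python
# 0.3.14 (old location)
# from torchzero.utils.linalg import qr, solve, linear_operator

# 0.4.0 - the same submodules now live one level up
from torchzero.linalg import qr, solve, linear_operator
```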
--- a/torchzero/modules/adaptive/sophia_h.py
+++ b/torchzero/modules/adaptive/sophia_h.py
@@ -1,52 +1,19 @@
-from typing import Literal
-from collections.abc import Callable
 import torch

-from ...core import
-from ...utils import NumberList, TensorList,
-
-
-
-
-    h_exp_avg_: TensorList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    update_freq: int,
-    precond_scale: float | NumberList,
-    clip: float | NumberList,
-    eps: float | NumberList,
-    step: int
-):
-    # momentum
-    exp_avg_.lerp_(tensors, 1-beta1)
-
-    # update preconditioner
-    if step % update_freq == 0:
-        assert h is not None
-        h_exp_avg_.lerp_(h, 1-beta2)
-
-    else:
-        assert h is None
-
-    denom = (h_exp_avg_ * precond_scale).clip_(min=eps)
-    return (exp_avg_ / denom).clip_(-clip, clip)
-
-
-class SophiaH(Module):
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+
+class SophiaH(Transform):
     """SophiaH optimizer from https://arxiv.org/abs/2305.14342

     This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is agressively clipped.

-
-    In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the
+    Notes:
+        - In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply SophiaH preconditioning to another module's output.

-
-    If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

     Args:
         beta1 (float, optional): first momentum. Defaults to 0.96.
@@ -60,46 +27,48 @@ class SophiaH(Module):
         eps (float, optional):
            clips hessian diagonal esimate to be no less than this value. Defaults to 1e-12.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are
-
-            - ``"
-
-            - ``"
-
-
-
-
-
-
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
         seed (int | None, optional): seed for random vectors. Defaults to None.
         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

-    Examples:
-        Using SophiaH:
+    ### Examples:

-
+    Using SophiaH:

-
-            model.parameters(),
-            tz.m.SophiaH(),
-            tz.m.LR(0.1)
-        )
+    ```python

-
-
-
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SophiaH(),
+        tz.m.LR(0.1)
+    )
+    ```

-
+    SophiaH preconditioner can be applied to any other module by passing it to the ``inner`` argument.
+    Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
+    SophiaH preconditioning to nesterov momentum (``tz.m.NAG``):

-
-            model.parameters(),
-            tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
-            tz.m.LR(0.1)
-        )
+    ```python

+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
+        tz.m.LR(0.1)
+    )
+    ```
     """
     def __init__(
         self,
@@ -109,77 +78,84 @@ class SophiaH(Module):
         precond_scale: float = 1,
         clip: float = 1,
         eps: float = 1e-12,
-        hvp_method:
-
+        hvp_method: HVPMethod = 'autograd',
+        distribution: Distributions = 'gaussian',
+        h: float = 1e-3,
         n_samples = 1,
+        zHz: bool = True,
+        debias: bool = False,
         seed: int | None = None,
-
+
+        exp_avg_tfm: Chainable | None = None,
+        D_exp_avg_tfm: Chainable | None = None,
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self'], defaults['exp_avg_tfm'], defaults["D_exp_avg_tfm"]
         super().__init__(defaults)

-
-
+        self.set_child('exp_avg', exp_avg_tfm)
+        self.set_child('D_exp_avg', D_exp_avg_tfm)

     @torch.no_grad
-    def
-        params =
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        fd_h = settings['fd_h']
-        update_freq = settings['update_freq']
-        n_samples = settings['n_samples']
+    def update_states(self, objective, states, settings):
+        params = objective.params

-
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
+        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)

-
-            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)
+        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg', cls=TensorList)

-
+        step = self.increment_counter("step", start=0) # 0 on 1st update

-
-
+        # ---------------------------- hutchinson hessian ---------------------------- #
+        fs = settings[0]
+        update_freq = fs['update_freq']

-        closure = var.closure
-        assert closure is not None
-
-        h = None
         if step % update_freq == 0:
+            self.increment_counter("num_Ds", start=1)
+
+            D, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = fs['n_samples'],
+                distribution = fs['distribution'],
+                hvp_method = fs['hvp_method'],
+                h = fs['h'],
+                zHz = fs["zHz"],
+                generator = self.get_generator(params[0].device, fs["seed"]),
+            )
+
+            D_exp_avg.lerp_(D, weight=1-beta2)
+
+        # --------------------------------- momentum --------------------------------- #
+        tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
+        exp_avg.lerp_(tensors, 1-beta1)
+
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+
+        beta1, beta2, eps, precond_scale, clip = unpack_dicts(
+            settings, 'beta1', 'beta2', 'eps', 'precond_scale', 'clip', cls=NumberList)
+
+        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg')
+
+        # ---------------------------------- debias ---------------------------------- #
+        if settings[0]["debias"]:
+            bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
+            bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])
+
+            exp_avg = exp_avg / bias_correction1
+            D_exp_avg = D_exp_avg / bias_correction2
+
+        # -------------------------------- transforms -------------------------------- #
+        exp_avg = TensorList(self.inner_step_tensors(
+            "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))
+
+        D_exp_avg = TensorList(self.inner_step_tensors(
+            "D_exp_avg", tensors=D_exp_avg, clone=True, objective=objective, must_exist=False))

-
-
-
-
-            Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
-                h=fd_h, normalize=True, retain_grad=i < n_samples-1)
-            Hvp = tuple(Hvp)
-
-            if h is None: h = Hvp
-            else: torch._foreach_add_(h, Hvp)
-
-        assert h is not None
-        if n_samples > 1: torch._foreach_div_(h, n_samples)
-
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
-
-        var.update = sophia_H(
-            tensors=TensorList(update),
-            h=TensorList(h) if h is not None else None,
-            exp_avg_=exp_avg,
-            h_exp_avg_=h_exp_avg,
-            beta1=beta1,
-            beta2=beta2,
-            update_freq=update_freq,
-            precond_scale=precond_scale,
-            clip=clip,
-            eps=eps,
-            step=step,
-        )
-        return var
+        # ------------------------------ compute update ------------------------------ #
+        denom = D_exp_avg.lazy_mul(precond_scale).clip(min=eps)
+        objective.updates = (exp_avg / denom).clip_(-clip, clip)
+        return objective
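For context on the new `update_states`/`apply_states` split above, here is a minimal usage sketch assembled from the docstring's own example. The model, data, and loss are placeholders, and it assumes `tz.Modular` exposes the usual `zero_grad`/`step(closure)` optimizer interface; the closure follows the ``backward``-argument convention the docstring says SophiaH needs for its hessian-vector products.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# the chain from the docstring example: SophiaH preconditioning followed by a fixed LR
opt = tz.Modular(model.parameters(), tz.m.SophiaH(), tz.m.LR(0.1))

def closure(backward=True):
    # SophiaH re-evaluates the loss for HVPs, so the closure must accept `backward`
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    opt.step(closure)
```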
--- a/torchzero/modules/clipping/clipping.py
+++ b/torchzero/modules/clipping/clipping.py
@@ -5,7 +5,7 @@ from typing import Literal

 import torch

-from ...core import Module,
+from ...core import Module, TensorTransform
 from ...utils import Metrics, NumberList, TensorList
 from ...utils.metrics import _METRICS

@@ -150,7 +150,7 @@ def normalize_grads_(
     _clip_norm_(grads, min=None, max=None, norm_value=norm_value, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)


-class ClipValue(
+class ClipValue(TensorTransform):
     """Clips update magnitude to be within ``(-value, value)`` range.

     Args:
@@ -180,17 +180,17 @@ class ClipValue(Transform):
     ```

     """
-    def __init__(self, value: float
+    def __init__(self, value: float):
         defaults = dict(value=value)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)

-class ClipNorm(
-    """Clips update norm to be no larger than
+class ClipNorm(TensorTransform):
+    """Clips update norm to be no larger than ``value``.

     Args:
         max_norm (float): value to clip norm to.
@@ -236,13 +236,12 @@ class ClipNorm(Transform):
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
-        target: Target = "update",
     ):
         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         max_norm = NumberList(s['max_norm'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
@@ -257,7 +256,7 @@ class ClipNorm(Transform):
         )
         return tensors

-class Normalize(
+class Normalize(TensorTransform):
     """Normalizes the update.

     Args:
@@ -304,13 +303,12 @@ class Normalize(Transform):
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
-        target: Target = "update",
     ):
         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         norm_value = NumberList(s['norm_value'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

@@ -362,7 +360,7 @@ def _centralize_(
     return tensors_


-class Centralize(
+class Centralize(TensorTransform):
     """Centralizes the update.

     Args:
@@ -395,13 +393,12 @@ class Centralize(Transform):
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 2,
-        target: Target = "update",
     ):
         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
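The hunks above show the 0.4.0 hook that the clipping modules now implement: `multi_tensor_apply(self, tensors, params, grads, loss, states, settings)` on a `TensorTransform` subclass, returning the transformed tensors. Below is a hypothetical custom transform following that same pattern; the class and its behaviour are illustrative, not part of the released API.

```python
import torch
from torchzero.core import TensorTransform


class ScaleUpdate(TensorTransform):
    """Toy transform mirroring the ClipValue pattern above: scales the update."""

    def __init__(self, scale: float = 0.5):
        defaults = dict(scale=scale)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        scale = [s['scale'] for s in settings]  # one settings dict per parameter
        torch._foreach_mul_(tensors, scale)     # in-place, like the built-in clipping modules
        return tensors
```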
--- a/torchzero/modules/clipping/ema_clipping.py
+++ b/torchzero/modules/clipping/ema_clipping.py
@@ -1,13 +1,14 @@
+from collections.abc import Iterable, Sequence
 from operator import itemgetter
 from typing import Literal
-from collections.abc import Iterable, Sequence

 import torch

-from ...core import
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ...core import Chainable, TensorTransform, step
+from ...utils import Metrics, NumberList, TensorList, unpack_dicts, unpack_states
+

-class ClipNormByEMA(
+class ClipNormByEMA(TensorTransform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.

     Args:
@@ -36,7 +37,7 @@ class ClipNormByEMA(Transform):
         super().__init__(defaults, inner=inner)

     @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])

@@ -83,7 +84,7 @@ class ClipNormByEMA(Transform):
         self.global_state['denom'] = denom

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         denom = self.global_state.pop('denom')
         torch._foreach_div_(tensors, denom)
         return tensors
@@ -106,45 +107,50 @@ class NormalizeByEMA(ClipNormByEMA):

 # TODO Centralize by EMA?

-class ClipValueByEMA(
+class ClipValueByEMA(TensorTransform):
     """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

     Args:
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ema_init (str, optional):
-            How to initialize exponential moving average on first step,
-
+            How to initialize exponential moving average on first step,
+            "update" to use the first update or "zeros". Defaults to 'zeros'.
+        exp_avg_tfm (Chainable | None, optional):
             optional modules applied to exponential moving average before clipping by it. Defaults to None.
     """
     def __init__(
         self,
         beta=0.99,
-
-
+        init: Literal['zeros', 'update'] = 'zeros',
+
         inner: Chainable | None = None,
+        exp_avg_tfm:Chainable | None=None,
     ):
-        defaults = dict(beta=beta,
+        defaults = dict(beta=beta, init=init)
         super().__init__(defaults, inner=inner)

-
-        self.set_child('ema_tfm', ema_tfm)
+        self.set_child('exp_avg', exp_avg_tfm)

-
-
-
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["init"] == "zeros":
+            state["exp_avg"] = torch.zeros_like(tensor)
+        else:
+            state["exp_avg"] = tensor.abs()

-
+    @torch.no_grad
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)

-
-
+        exp_avg = unpack_states(states, tensors, 'exp_avg', must_exist=True, cls=TensorList)
+        exp_avg.lerp_(tensors.abs(), 1-beta)

-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-
+        exp_avg = unpack_states(states, tensors, 'exp_avg')

-
-
+        exp_avg = TensorList(
+            self.inner_step_tensors("exp_avg", exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

-        tensors.clip_(-
+        tensors.clip_(-exp_avg, exp_avg)
         return tensors
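The `ClipValueByEMA` changes above boil down to: keep an exponential moving average of the update magnitudes, then clamp the update elementwise to that average. A plain-PyTorch sketch of that logic for a single tensor (the real module works on `TensorList`s and supports child transforms, which this omits):

```python
import torch

def clip_value_by_ema(update: torch.Tensor, exp_avg: torch.Tensor, beta: float = 0.99) -> torch.Tensor:
    # EMA of magnitudes: exp_avg <- beta * exp_avg + (1 - beta) * |update|
    exp_avg.lerp_(update.abs(), 1 - beta)
    # clamp the update elementwise into [-exp_avg, exp_avg]
    return update.clamp(-exp_avg, exp_avg)

u = torch.randn(5) * 3
ema = torch.zeros_like(u)          # "zeros" init, the default in 0.4.0
clipped = clip_value_by_ema(u, ema)
```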
--- a/torchzero/modules/clipping/growth_clipping.py
+++ b/torchzero/modules/clipping/growth_clipping.py
@@ -2,11 +2,11 @@ from operator import itemgetter

 import torch

-from ...core import
-from ...utils import TensorList
+from ...core import TensorTransform
+from ...utils import TensorList


-class ClipValueGrowth(
+class ClipValueGrowth(TensorTransform):
     """Clips update value magnitude growth.

     Args:
@@ -27,13 +27,12 @@ class ClipValueGrowth(TensorwiseTransform):
         mul: float | None = 1.5,
         min_value: float | None = 1e-4,
         max_decay: float | None = 2,
-        target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
-        super().__init__(defaults
+        super().__init__(defaults)


-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
         add: float | None

@@ -115,7 +114,7 @@ def norm_growth_clip_(
     return tensor_.div_(denom), new_prev_norm, denom


-class ClipNormGrowth(
+class ClipNormGrowth(TensorTransform):
     """Clips update norm growth.

     Args:
@@ -130,7 +129,7 @@ class ClipNormGrowth(Transform):
            Next norm is at most :code:`max(previous norm * mul, max_decay)`.
            Defaults to 2.
         ord (float, optional): norm order. Defaults to 2.
-
+        tensorwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         target (Target, optional): what to set on var. Defaults to "update".
     """
@@ -141,19 +140,17 @@ class ClipNormGrowth(Transform):
         min_value: float | None = 1e-4,
         max_decay: float | None = 2,
         ord: float = 2,
-
-        target: Target = "update",
+        tensorwise=True,
     ):
-        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord,
-        super().__init__(defaults
+        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, tensorwise=tensorwise)
+        super().__init__(defaults)


-
-
-        parameterwise = settings[0]['parameterwise']
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensorwise = settings[0]['tensorwise']
         tensors = TensorList(tensors)

-        if
+        if tensorwise:
            ts = tensors
            stts = states
            stns = settings
@@ -180,7 +177,7 @@ class ClipNormGrowth(Transform):
                ord = setting['ord'],
            )

-        if not
+        if not tensorwise:
            tensors.from_vec_(ts[0])

        return tensors
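`ClipNormGrowth` above limits how fast the update norm may grow between steps, via additive `add` and multiplicative `mul` caps with a `min_value` floor. A plain-PyTorch sketch of that idea for a single tensor; the exact interplay of `min_value` and `max_decay` in torchzero may differ, so treat this as an illustration rather than the library's implementation.

```python
import torch

def clip_norm_growth(update, prev_norm, add=None, mul=1.5, min_value=1e-4, ord=2):
    """Scale `update` so its norm grows at most additively/multiplicatively vs. `prev_norm`."""
    norm = torch.linalg.vector_norm(update, ord=ord).item()
    if prev_norm is None or norm == 0:
        return update, max(norm, min_value)

    allowed = norm
    if mul is not None:
        allowed = min(allowed, max(prev_norm * mul, min_value))   # multiplicative cap
    if add is not None:
        allowed = min(allowed, prev_norm + add)                   # additive cap

    if allowed < norm:
        update = update * (allowed / norm)
    return update, max(allowed, min_value)

u, prev = torch.randn(10) * 100.0, 1.0
u, prev = clip_norm_growth(u, prev)   # norm capped at max(prev * 1.5, 1e-4)
```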