torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/adahessian.py

@@ -3,9 +3,8 @@ from typing import Literal

 import torch

-from ...core import Chainable,
-from ...utils import NumberList, TensorList,
-from ..functional import debiased_step_size
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import NumberList, TensorList, Distributions, unpack_dicts, unpack_states

 def _full_average(hvp: torch.Tensor):
     if hvp.ndim >= 3: # Conv kernel
@@ -37,41 +36,7 @@ def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
     return x


-
-    """p is probability of a 1, other values will be -1."""
-    return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
-
-def adahessian(
-    tensors: TensorList,
-    D: TensorList | None,
-    exp_avg_: TensorList,
-    D_exp_avg_sq_: TensorList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    update_freq: int,
-    eps: float | NumberList,
-    hessian_power: float | NumberList,
-    step: int,
-):
-    # momentum
-    exp_avg_.lerp_(tensors, 1-beta1)
-
-    # update preconditioner
-    if step % update_freq == 0:
-        assert D is not None
-        D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
-
-    else:
-        assert D is None
-
-
-    denom = D_exp_avg_sq_.sqrt().pow_(hessian_power).add_(eps)
-    num = exp_avg_ * debiased_step_size(step+1, beta1, beta2)
-
-    return num.div_(denom)
-
-
-class AdaHessian(Module):
+class AdaHessian(Transform):
     """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

     This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
@@ -79,8 +44,6 @@ class AdaHessian(Module):
     Notes:
         - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.

-        - If you are using gradient estimators or reformulations, set ``hvp_method`` to "forward" or "central".
-
         - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

     Args:
@@ -97,17 +60,17 @@ class AdaHessian(Module):
         eps (float, optional):
             division stability epsilon. Defaults to 1e-8.
         hvp_method (str, optional):
-            Determines how
-
-            - ``"
-
-            - ``"
-
-
-
-
-
-
+            Determines how hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -123,7 +86,7 @@ class AdaHessian(Module):
     Using AdaHessian:

     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.AdaHessian(),
         tz.m.LR(0.1)
@@ -134,7 +97,7 @@ class AdaHessian(Module):
     Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
     AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
         tz.m.LR(0.1)
@@ -151,74 +114,82 @@ class AdaHessian(Module):
         update_freq: int = 1,
         eps: float = 1e-8,
         hessian_power: float = 1,
-
-
+        distribution: Distributions = 'rademacher',
+        hvp_method: HVPMethod = 'autograd',
+        h: float = 1e-3,
         n_samples = 1,
+        zHz: bool = True,
+        debias: bool = True,
         seed: int | None = None,
-
+
+        exp_avg_tfm: Chainable | None = None,
+        D_exp_avg_sq_tfm: Chainable | None = None,
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self'], defaults["exp_avg_tfm"], defaults["D_exp_avg_sq_tfm"]
         super().__init__(defaults)

-
-
+        self.set_child('exp_avg', exp_avg_tfm)
+        self.set_child('D_exp_avg_sq', D_exp_avg_sq_tfm)

     @torch.no_grad
-    def
-        params =
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        fd_h = settings['fd_h']
-        update_freq = settings['update_freq']
-        n_samples = settings['n_samples']
+    def update_states(self, objective, states, settings):
+        params = objective.params

-
-        generator = self.get_generator(params[0].device, seed)
+        beta1, beta2, averaging, block_size = unpack_dicts(settings, 'beta1', 'beta2', 'averaging', 'block_size', cls=NumberList)

-
-            'beta1', 'beta2', 'eps', 'averaging', 'block_size', "hessian_power", cls=NumberList)
+        exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)

-
+        # ---------------------------- hutchinson hessian ---------------------------- #
+        fs = settings[0]
+        step = self.increment_counter("step", start=0) # 0 on 1st update
+        update_freq = fs['update_freq']

-        step
-
+        if step % update_freq == 0:
+            self.increment_counter("num_Ds", start=1)
+
+            D, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = fs['n_samples'],
+                distribution = fs['distribution'],
+                hvp_method = fs['hvp_method'],
+                h = fs['h'],
+                zHz = fs["zHz"],
+                generator = self.get_generator(params[0].device, fs["seed"]),
+            )

-
-
+            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+            D_exp_avg_sq.mul_(beta2).addcmul_(D, D, value=1-beta2)

-
-
+        # --------------------------------- momentum --------------------------------- #
+        tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
+        exp_avg.lerp_(tensors, 1-beta1)

-        rgrad=None
-        for i in range(n_samples):
-            u = [_rademacher_like(p, generator=generator) for p in params]

-
-
-
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        params = objective.params

-
-
+        beta1, beta2, eps, hessian_power = unpack_dicts(settings, 'beta1', 'beta2', 'eps', 'hessian_power', cls=NumberList)
+        exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)

-
-
+        # ---------------------------------- debias ---------------------------------- #
+        if settings[0]["debias"]:
+            bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
+            bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])
+            exp_avg = exp_avg / bias_correction1
+            D_exp_avg_sq = D_exp_avg_sq / bias_correction2

-        D = TensorList(D).zipmap_args(_block_average, block_size, averaging)

-
-
-
-
-
-        tensors=
-
-
-
-
-
-            update_freq=update_freq,
-            eps=eps,
-            hessian_power=hessian_power,
-            step=step,
-        )
-        return var
+        # -------------------------------- transforms -------------------------------- #
+        exp_avg = TensorList(self.inner_step_tensors(
+            "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))
+
+        D_exp_avg_sq = TensorList(self.inner_step_tensors(
+            "D_exp_avg_sq", tensors=D_exp_avg_sq, clone=True, objective=objective, must_exist=False))
+
+        # ------------------------------ compute update ------------------------------ #
+        denom = D_exp_avg_sq.lazy_pow(hessian_power / 2) + eps
+        objective.updates = exp_avg / denom
+        return objective
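In the rewritten module the Hutchinson diagonal estimate is delegated to `objective.hutchinson_hessian`, replacing the removed `_rademacher_like` helper and the manual `for i in range(n_samples)` loop. A minimal standalone sketch of that estimate in plain PyTorch follows (illustrative only, not torchzero's API; the probe construction mirrors the removed helper):

```python
import torch

def hutchinson_diag(loss_fn, params, n_samples=1, generator=None):
    """Estimate diag(H) of loss_fn at params via E[z * Hz] with Rademacher probes z."""
    grads = torch.autograd.grad(loss_fn(params), params, create_graph=True)
    D = [torch.zeros_like(p) for p in params]
    for _ in range(n_samples):
        # Rademacher probe, built the same way as the removed _rademacher_like helper
        z = [torch.bernoulli(torch.full_like(p, 0.5), generator=generator).mul_(2).sub_(1)
             for p in params]
        Hz = torch.autograd.grad(grads, params, grad_outputs=z, retain_graph=True)
        for d, hz, zi in zip(D, Hz, z):
            d.add_(hz * zi, alpha=1.0 / n_samples)  # average z * Hz over samples
    return D

# For a separable loss the estimate is exact, since z * z == 1 elementwise:
w = torch.randn(5, requires_grad=True)
diag_est = hutchinson_diag(lambda ps: (ps[0] ** 4).sum(), [w], n_samples=3)
```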
torchzero/modules/adaptive/adam.py

@@ -1,48 +1,11 @@
-from operator import itemgetter
-from functools import partial
-
 import torch

-from ...core import
+from ...core import Chainable, Module, TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..
-
-
-
-)
-
-
-def adam_(
-    tensors: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_sq_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    debiased: bool = True,
-    max_exp_avg_sq_: TensorList | None = None,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """Returns new tensors."""
-    sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-                                   debiased=False,step=step,pow=pow)
-
-    if inner is not None:
-        assert params is not None
-        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
-
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-    return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))
-
-class Adam(Transform):
+from ..opt_utils import debiased_step_size
+
+
+class Adam(TensorTransform):
     """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

     This implementation is identical to :code:`torch.optim.Adam`.
@@ -54,7 +17,7 @@ class Adam(Transform):
         alpha (float, optional): learning rate. Defaults to 1.
         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
         pow (float, optional): power used in second momentum power and root. Defaults to 2.
-
+        debias (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
     """
     def __init__(
         self,
@@ -63,45 +26,59 @@ class Adam(Transform):
         eps: float = 1e-8,
         amsgrad: bool = False,
         alpha: float = 1.,
-
-
-
+        debias: bool = True,
+
+        exp_avg_tfm: Chainable | None = None,
+        exp_avg_sq_tfm: Chainable | None = None,
     ):
-        defaults=
-
+        defaults = locals().copy()
+        del defaults['self'], defaults["exp_avg_tfm"], defaults["exp_avg_sq_tfm"]
+        super().__init__(defaults)

-
+        self.set_child('exp_avg', exp_avg_tfm)
+        self.set_child('exp_avg_sq', exp_avg_sq_tfm)

     @torch.no_grad
-    def
-
-
-
-
-
-
-
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        self.increment_counter("step", start=0)
+        beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
+
+        # ----------------------------- initialize states ---------------------------- #
+        if settings[0]["amsgrad"]:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(
+                states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
             exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None

+        # ------------------------------ update moments ------------------------------ #
+        exp_avg.lerp_(tensors, weight=1-beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(tensors, tensors, value=1-beta2)
+
+        if max_exp_avg_sq is not None:
+            max_exp_avg_sq.maximum_(exp_avg_sq)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state["step"] # 0 on 1st step
+        fs = settings[0]
+
+        if fs["amsgrad"]: key = "max_exp_avg_sq"
+        else: key = "exp_avg_sq"
+        exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', key, cls=TensorList)
+        beta1, beta2, alpha, eps = unpack_dicts(settings, 'beta1', 'beta2', 'alpha', 'eps', cls=NumberList)
+
+        # -------------------------------- transforms -------------------------------- #
+        exp_avg = TensorList(self.inner_step_tensors(
+            "exp_avg", tensors=exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+
+        exp_avg_sq = TensorList(self.inner_step_tensors(
+            "exp_avg_sq", tensors=exp_avg_sq, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+
+        # ---------------------------------- debias ---------------------------------- #
+        if fs["debias"]:
+            alpha = debiased_step_size((step + 1), beta1=beta1, beta2=beta2, alpha=alpha)
+            exp_avg = exp_avg * alpha

-
-
-            exp_avg_=exp_avg,
-            exp_avg_sq_=exp_avg_sq,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            eps=eps,
-            step=step,
-            pow=pow,
-            debiased=debiased,
-            max_exp_avg_sq_=max_exp_avg_sq,
-
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-
-        )
+        # ---------------------------------- update ---------------------------------- #
+        return exp_avg / exp_avg_sq.sqrt().add_(eps)
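The step logic is now split across `multi_tensor_update` (accumulate the moments) and `multi_tensor_apply` (build the update), with bias correction folded into the step size via `debiased_step_size` from `opt_utils`. A single-tensor sketch of the same arithmetic, assuming the conventional folded bias correction (the exact behaviour of `debiased_step_size` is not shown in this diff):

```python
import torch

def adam_step(g, exp_avg, exp_avg_sq, step, beta1=0.9, beta2=0.999, eps=1e-8, alpha=1.0):
    # update phase: accumulate first and second moments in place
    exp_avg.lerp_(g, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    # apply phase: fold both bias corrections into the step size (assumed formula) ...
    alpha_t = alpha * (1 - beta2 ** step) ** 0.5 / (1 - beta1 ** step)
    # ... and divide the scaled first moment by the root of the second moment
    return exp_avg * alpha_t / (exp_avg_sq.sqrt() + eps)

g = torch.tensor([0.5, -1.0, 2.0])
update = adam_step(g, torch.zeros(3), torch.zeros(3), step=1)  # roughly sign(g) on step 1
```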
torchzero/modules/adaptive/adan.py

@@ -1,9 +1,9 @@
 import torch

-from ...core import
+from ...core import TensorTransform, Chainable
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states

-def
+def adan_update_(
     g: TensorList,
     g_prev_: TensorList,
     m_: TensorList, # exponential moving average
@@ -12,10 +12,8 @@ def adan_(
     beta1: float | NumberList,
     beta2: float | NumberList,
     beta3: float | NumberList,
-    eps: float | NumberList,
     step: int,
 ):
-    """Returns new tensors"""
     m_.lerp_(g, 1 - beta1)

     if step == 1:
@@ -26,7 +24,18 @@ def adan_(
     term = g + beta2 * diff

     n_.mul_(beta3).addcmul_(term, term, value=(1 - beta3))
+    g_prev_.copy_(g)

+def adan_apply_(
+    m_: TensorList, # exponential moving average
+    v_: TensorList, # exponential moving average of gradient differences
+    n_: TensorList, # kinda like squared momentum
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+):
     m = m_ / (1.0 - beta1**step)
     v = v_ / (1.0 - beta2**step)
     n = n_ / (1.0 - beta3**step)
@@ -35,13 +44,12 @@ def adan_(
     num = m + beta2 * v

     update = num.div_(denom)
-    g_prev_.copy_(g)

     return update



-class Adan(
+class Adan(TensorTransform):
     """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

     Args:
@@ -49,18 +57,17 @@ class Adan(Transform):
         beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
         beta3 (float, optional): thrid (squared) momentum. Defaults to 0.99.
         eps (float, optional): epsilon. Defaults to 1e-8.
-        use_n_prev (bool, optional):
-            whether to use previous gradient differences momentum.

     Example:
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adan(),
         tz.m.LR(1e-3),
     )
+    ```
     Reference:
-        Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence
+        [Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence](https://arxiv.org/abs/2208.06677).
     """
     def __init__(
         self,
@@ -68,29 +75,41 @@ class Adan(Transform):
         beta2: float = 0.92,
         beta3: float = 0.99,
         eps: float = 1e-8,
+
+        m_tfm: Chainable | None = None,
+        v_tfm: Chainable | None = None,
+        n_tfm: Chainable | None = None,
     ):
-        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps)
+        defaults=dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps)
         super().__init__(defaults, uses_grad=False)

+        self.set_child("m", m_tfm)
+        self.set_child("v", v_tfm)
+        self.set_child("n", n_tfm)
+
     @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-        step = self.
-
-        beta1,beta2,beta3
-        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-
+        step = self.increment_counter("step", start=0)
+
+        beta1, beta2, beta3 = unpack_dicts(settings, 'beta1','beta2','beta3', cls=NumberList)
+        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev', 'm', 'v', 'n', cls=TensorList)
+
+        adan_update_(g=tensors, g_prev_=g_prev, m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, step=step+1)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state["step"] # 0 on 1st step
+
+        beta1, beta2, beta3, eps = unpack_dicts(settings, 'beta1','beta2','beta3', 'eps', cls=NumberList)
+        m, v, n = unpack_states(states, tensors, 'm', 'v', 'n')
+
+        # -------------------------------- transforms -------------------------------- #
+        m = TensorList(self.inner_step_tensors("m", m, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+        v = TensorList(self.inner_step_tensors("v", v, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+        n = TensorList(self.inner_step_tensors("n", n, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+
+        # ---------------------------------- update ---------------------------------- #
+        return adan_apply_(m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, step=step+1)
+
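The single `adan_` kernel is now split: `adan_update_` accumulates the three moments and copies `g_prev_` during the update phase, while `adan_apply_` only debiases them and forms the step. A single-tensor sketch of the apply-side arithmetic (the `denom` line sits outside the hunks above and is assumed to be `n.sqrt() + eps`, as in the Adan paper):

```python
import torch

def adan_apply_sketch(m, v, n, step, beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8):
    m_hat = m / (1.0 - beta1 ** step)  # debiased EMA of gradients
    v_hat = v / (1.0 - beta2 ** step)  # debiased EMA of gradient differences
    n_hat = n / (1.0 - beta3 ** step)  # debiased EMA of the squared Nesterov term
    denom = n_hat.sqrt() + eps         # assumed; not visible in the hunks above
    return (m_hat + beta2 * v_hat) / denom
```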
torchzero/modules/adaptive/adaptive_heavyball.py

@@ -1,5 +1,5 @@
 import torch
-from ...core import
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_dicts, unpack_states


@@ -16,10 +16,10 @@ def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p:
     return (1 + m) * h * g - m*(p-p_prev)


-class AdaptiveHeavyBall(
+class AdaptiveHeavyBall(TensorTransform):
     """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

-
+    Suitable for quadratic objectives with known f* (loss at minimum).

     note:
         The step size is determined by the algorithm, so learning rate modules shouldn't be used.
@@ -30,25 +30,30 @@ class AdaptiveHeavyBall(Transform):
     """
     def __init__(self, f_star: float = 0):
         defaults = dict(f_star=f_star)
-        super().__init__(defaults,
+        super().__init__(defaults, uses_loss=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert loss is not None
         tensors = TensorList(tensors)
-        f_star =
+        f_star = settings[0]['f_star']

         f_prev = self.global_state.get('f_prev', None)
         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)

+        # -------------------------------- first step -------------------------------- #
         if f_prev is None:
             self.global_state['f_prev'] = loss
             h = 2*(loss - f_star) / tensors.dot(tensors)
             return h * tensors

-
+        # ------------------------------- further steps ------------------------------ #
+        update = adaptive_heavy_ball(
+            f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)

+        # --------------------------- store previous values -------------------------- #
         self.global_state['f_prev'] = loss
         p_prev.copy_(params)
         g_prev.copy_(tensors)
+
         return update
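The visible control flow is unchanged: on the first step the update is a Polyak-style step h·g with h = 2(f − f*)/‖g‖², and later steps call `adaptive_heavy_ball` with the stored previous loss, gradient and parameters. A sketch of the first-step rule only (illustrative; the heavy-ball momentum term is computed in lines outside these hunks):

```python
import torch

def first_heavy_ball_step(loss, grad, f_star=0.0):
    # h = 2 (f - f*) / ||g||^2, applied to the raw gradient
    h = 2 * (loss - f_star) / grad.dot(grad)
    return h * grad

# On a quadratic f(x) = 0.5 * ||x||^2 with f* = 0, grad = x and h = 1,
# so subtracting the returned update reaches the minimum in one step.
x = torch.tensor([3.0, -4.0])
assert torch.allclose(x - first_heavy_ball_step(0.5 * x.dot(x), x), torch.zeros(2))
```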
torchzero/modules/adaptive/aegd.py

@@ -2,17 +2,18 @@ import math

 import torch

-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states

 # i've verified, it is identical to official
 # https://github.com/txping/AEGD/blob/master/aegd.py
+# TODO: add a test
 def aegd_(f: torch.Tensor | float, g: TensorList, r_: TensorList, c:float|NumberList=1, eta:float|NumberList=0.1) -> TensorList:
     v = g / (2 * (f + c)**0.5)
     r_ /= 1 + (v ** 2).mul_(2*eta) # update energy
     return 2*eta * r_*v # pyright:ignore[reportReturnType]

-class AEGD(
+class AEGD(TensorTransform):
     """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

     Note:
@@ -20,28 +21,27 @@ class AEGD(Transform):
         To avoid compounding learning rate mofications, remove the ``tz.m.LR`` module if you had it.

     Args:
-
-        c (float, optional):
-
-
-
-            whether to use previous gradient differences momentum.
+        lr (float, optional): learning rate (default: 0.1)
+        c (float, optional): term added to the original objective function (default: 1)
+
+    Reference:
+        [Liu, Hailiang, and Xuping Tian. "AEGD: Adaptive gradient descent with energy." arXiv preprint arXiv:2010.05109 (2020).](https://arxiv.org/pdf/2010.05109)
     """
     def __init__(
         self,
         lr: float = 0.1,
         c: float = 1,
     ):
-        defaults=dict(c=c,lr=lr)
+        defaults = dict(c=c, lr=lr)
         super().__init__(defaults, uses_loss=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert loss is not None
         tensors = TensorList(tensors)

-        c,lr=unpack_dicts(settings, 'c','lr', cls=NumberList)
-        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss+c[0])**0.5), cls=TensorList)
+        c, lr = unpack_dicts(settings, 'c', 'lr', cls=NumberList)
+        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss + c[0])**0.5), cls=TensorList)

         update = aegd_(
             f=loss,