torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/spectral.py
DELETED

@@ -1,288 +0,0 @@
from abc import ABC, abstractmethod
import math
from collections import deque
from typing import Literal, Any

import torch
from ...core import Chainable, TensorwisePreconditioner
from ...utils.linalg.matrix_funcs import matrix_power_eigh
from ...utils.linalg.svd import randomized_svd
from ...utils.linalg.qr import qr_householder


class _Solver:
    @abstractmethod
    def update(self, history: deque[torch.Tensor], damping: float | None) -> tuple[Any, Any]:
        """returns stuff for apply"""
    @abstractmethod
    def apply(self, __g: torch.Tensor, __A:torch.Tensor, __B:torch.Tensor) -> torch.Tensor:
        """apply preconditioning to tensor"""

class _SVDSolver(_Solver):
    def __init__(self, driver=None): self.driver=driver
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        device = None # driver is CUDA only
        if self.driver is not None:
            device = M_hist.device
            M_hist = M_hist.cuda()

        try:
            U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver=self.driver) # pylint:disable=not-callable

            if self.driver is not None:
                U = U.to(device); S = S.to(device)

            if damping is not None and damping != 0: S.add_(damping)
            return U, S

        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
        Utg = (U.T @ g).div_(S)
        return U @ Utg

class _SVDLowRankSolver(_Solver):
    def __init__(self, q: int = 6, niter: int = 2): self.q, self.niter = q, niter
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            U, S, _ = torch.svd_lowrank(M_hist, q=self.q, niter=self.niter)
            if damping is not None and damping != 0: S.add_(damping)
            return U, S
        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
        Utg = (U.T @ g).div_(S)
        return U @ Utg

class _RandomizedSVDSolver(_Solver):
    def __init__(self, k: int = 3, driver: str | None = 'gesvda'):
        self.driver = driver
        self.k = k

    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        device = None # driver is CUDA only
        if self.driver is not None:
            device = M_hist.device
            M_hist = M_hist.cuda()

        try:
            U, S, _ = randomized_svd(M_hist, k=self.k, driver=self.driver)

            if self.driver is not None:
                U = U.to(device); S = S.to(device)

            if damping is not None and damping != 0: S.add_(damping)
            return U, S

        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
        Utg = (U.T @ g).div_(S)
        return U @ Utg

class _QRDiagonalSolver(_Solver):
    def __init__(self, sqrt=True): self.sqrt = sqrt
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
            R_diag = R.diag().abs()
            if damping is not None and damping != 0: R_diag.add_(damping)
            if self.sqrt: R_diag.sqrt_()
            return Q, R_diag
        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, R_diag: torch.Tensor):
        Qtg = (Q.T @ g).div_(R_diag)
        return Q @ Qtg

class _QRSolver(_Solver):
    def __init__(self, sqrt=True): self.sqrt = sqrt
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            # Q: d x k, R: k x k
            Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
            A = R @ R.T
            if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
            if self.sqrt: A = matrix_power_eigh(A, 0.5)
            return Q, A
        except (torch.linalg.LinAlgError):
            return None,None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        g_proj = Q.T @ g
        y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
        return Q @ y

class _QRHouseholderSolver(_Solver):
    def __init__(self, sqrt=True): self.sqrt = sqrt
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            # Q: d x k, R: k x k
            Q, R = qr_householder(M_hist, mode='reduced') # pylint:disable=not-callable
            A = R @ R.T
            if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
            if self.sqrt: A = matrix_power_eigh(A, 0.5)
            return Q, A
        except (torch.linalg.LinAlgError):
            return None,None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        g_proj = Q.T @ g
        y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
        return Q @ y


class _EighSolver(_Solver):
    def __init__(self, sqrt=True):
        self.sqrt = sqrt

    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        grams = M_hist @ M_hist.T # (d, d)
        if damping is not None and damping != 0: grams.diagonal(dim1=-2, dim2=-1).add_(damping)
        try:
            L, Q = torch.linalg.eigh(grams) # L: (d,), Q: (d, d) # pylint:disable=not-callable
            L = L.abs().clamp_(min=1e-12)
            if self.sqrt: L = L.sqrt()
            return Q, L
        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
        Qtg = (Q.T @ g).div_(L)
        return Q @ Qtg


SOLVERS = {
    "svd": _SVDSolver(), # fallbacks on "gesvd" which basically takes ages or just hangs completely
    "svd_gesvdj": _SVDSolver("gesvdj"), # no fallback on slow "gesvd"
    "svd_gesvda": _SVDSolver("gesvda"), # approximate method for wide matrices, sometimes better sometimes worse but faster
    "svd_lowrank": _SVDLowRankSolver(), # maybe need to tune parameters for this, with current ones its slower and worse
    "randomized_svd2": _RandomizedSVDSolver(2),
    "randomized_svd3": _RandomizedSVDSolver(3),
    "randomized_svd4": _RandomizedSVDSolver(4),
    "randomized_svd5": _RandomizedSVDSolver(5),
    "eigh": _EighSolver(), # this is O(n**2) storage, but is this more accurate?
    "qr": _QRSolver(),
    "qr_householder": _QRHouseholderSolver(), # this is slower... but maybe it won't freeze? I think svd_gesvda is better
    "qrdiag": _QRDiagonalSolver(),
}

def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
    if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
    else:
        if state_[key].shape != value.shape: state_[key] = value
        else: state_[key].lerp_(value, 1-beta)

class SpectralPreconditioner(TensorwisePreconditioner):
    """Whitening preconditioner via SVD on history of past gradients or gradient differences scaled by parameter differences.

    Args:
        history_size (int, optional): number of past gradients to store for preconditioning. Defaults to 10.
        update_freq (int, optional): how often to re-compute the preconditioner. Defaults to 1.
        damping (float, optional): damping term, makes it closer to GD. Defaults to 1e-7.
        order (int, optional):
            whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
        solver (str, optional): what to use for whitening. Defaults to 'svd'.
        A_beta (float | None, optional):
            beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
        B_beta (float | None, optional):
            beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
        interval (int, optional): How often to update history. Defaults to 1 (every step).
        concat_params (bool, optional):
            whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
        scale_first (bool, optional): makes first step small, usually not needed. Defaults to False.
        inner (Chainable | None, optional): Inner modules applied after updating preconditioner and before applying it. Defaults to None.
    """
    def __init__(
        self,
        history_size: int = 10,
        update_freq: int = 1,
        damping: float = 1e-12,
        order: int = 1,
        solver: Literal['svd', 'svd_gesvdj', 'svd_gesvda', 'svd_lowrank', 'eigh', 'qr', 'qrdiag', 'qr_householder'] | _Solver | str = 'svd_gesvda',
        A_beta: float | None = None,
        B_beta: float | None = None,
        interval: int = 1,
        concat_params: bool = False,
        scale_first: bool = False,
        inner: Chainable | None = None,
    ):
        if isinstance(solver, str): solver = SOLVERS[solver]
        # history is still updated each step so Precondition's update_freq has different meaning
        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, order=order, A_beta=A_beta, B_beta=B_beta, solver=solver)
        super().__init__(defaults, uses_grad=False, concat_params=concat_params, scale_first=scale_first, inner=inner, update_freq=interval)

    @torch.no_grad
    def update_tensor(self, tensor, param, grad, state, settings):
        order = settings['order']
        history_size = settings['history_size']
        update_freq = settings['update_freq']
        damping = settings['damping']
        A_beta = settings['A_beta']
        B_beta = settings['B_beta']
        solver: _Solver = settings['solver']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        if order == 1: history.append(tensor.clone().view(-1))
        else:

            # if order=2, history is of gradient differences, order 3 is differences between differences, etc
            # normalized by parameter differences
            cur_p = param.clone()
            cur_g = tensor.clone()
            for i in range(1, order):
                if f'prev_g_{i}' not in state:
                    state[f'prev_p_{i}'] = cur_p
                    state[f'prev_g_{i}'] = cur_g
                    break

                s_k = cur_p - state[f'prev_p_{i}']
                y_k = cur_g - state[f'prev_g_{i}']
                state[f'prev_p_{i}'] = cur_p
                state[f'prev_g_{i}'] = cur_g
                cur_p = s_k
                cur_g = y_k

                if i == order - 1:
                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
                    history.append(cur_g.view(-1))

        step = state.get('step', 0)
        if step % update_freq == 0 and len(history) != 0:
            A, B = solver.update(history, damping=damping)
            maybe_lerp_(state, A_beta, 'A', A)
            maybe_lerp_(state, B_beta, 'B', B)

        if len(history) != 0:
            state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, state, settings):
        history_size = settings['history_size']
        solver: _Solver = settings['solver']

        A = state.get('A', None)
        if A is None:
            # make a conservative step to avoid issues due to different GD scaling
            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

        B = state['B']
        update = solver.apply(tensor.view(-1), A, B).view_as(tensor)

        n = len(state['history'])
        if n != history_size: update.mul_(n/history_size)
        return update
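For orientation, the removed SpectralPreconditioner boils down to whitening the update against a small history of recent gradients: stack the stored gradients into a d x k matrix, take its thin SVD, and apply U diag(1/(S+damping)) U^T to the incoming gradient. A minimal standalone sketch of that core step using only plain torch (the names whiten and history below are illustrative, not torchzero API):

# Illustrative sketch of SVD whitening against a gradient history; this mirrors
# the removed module's core computation, not its exact API or bookkeeping.
from collections import deque
import torch

def whiten(history: deque, g: torch.Tensor, damping: float = 1e-12) -> torch.Tensor:
    # M is d x k: one column per stored gradient
    M = torch.stack(tuple(history), dim=1)
    U, S, _ = torch.linalg.svd(M, full_matrices=False)
    S = S + damping
    # U diag(1/S) U^T g, applied without forming the d x d matrix
    return U @ ((U.T @ g) / S)

torch.manual_seed(0)
history = deque(maxlen=10)
for _ in range(5):
    history.append(torch.randn(20))   # stand-ins for gradients of a 20-dim problem
g = torch.randn(20)
print(whiten(history, g).shape)       # torch.Size([20])

With damping near zero this rescales the components of g that lie in the span of the stored history by the inverse singular values, which is what the various _Solver backends above compute in different ways.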
torchzero/modules/experimental/structured_newton.py
DELETED

@@ -1,111 +0,0 @@
# idea https://arxiv.org/pdf/2212.09841
import warnings
from collections.abc import Callable
from functools import partial
from typing import Literal

import torch

from ...core import Chainable, Module, apply
from ...utils import TensorList, vec_to_tensors
from ...utils.derivatives import (
    hessian_list_to_mat,
    hessian_mat,
    hvp,
    hvp_fd_central,
    hvp_fd_forward,
    jacobian_and_hessian_wrt,
)


class StructuredNewton(Module):
    """TODO
    Args:
        structure (str, optional): structure.
        reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
        hvp_method (str):
            how to calculate hvp_method. Defaults to "autograd".
        inner (Chainable | None, optional): inner modules. Defaults to None.

    """
    def __init__(
        self,
        structure: Literal[
            "diagonal",
            "diagonal1",
            "diagonal_abs",
            "tridiagonal",
            "circulant",
            "toeplitz",
            "toeplitz_like",
            "hankel",
            "rank1",
            "rank2", # any rank
        ]
        | str = "diagonal",
        reg: float = 1e-6,
        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
        defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, vars):
        params = TensorList(vars.params)
        closure = vars.closure
        if closure is None: raise RuntimeError('NewtonCG requires closure')

        settings = self.settings[params[0]]
        reg = settings['reg']
        hvp_method = settings['hvp_method']
        structure = settings['structure']
        h = settings['h']

        # ------------------------ calculate grad and hessian ------------------------ #
        if hvp_method == 'autograd':
            grad = vars.get_grad(create_graph=True)
            def Hvp_fn1(x):
                return hvp(params, grad, x, retain_graph=True)
            Hvp_fn = Hvp_fn1

        elif hvp_method == 'forward':
            grad = vars.get_grad()
            def Hvp_fn2(x):
                return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
            Hvp_fn = Hvp_fn2

        elif hvp_method == 'central':
            grad = vars.get_grad()
            def Hvp_fn3(x):
                return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
            Hvp_fn = Hvp_fn3

        else: raise ValueError(hvp_method)

        # -------------------------------- inner step -------------------------------- #
        update = vars.get_update()
        if 'inner' in self.children:
            update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)

        # hessian
        if structure.startswith('diagonal'):
            H = Hvp_fn([torch.ones_like(p) for p in params])
            if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
            if structure == 'diagonal_abs': torch._foreach_abs_(H)
            torch._foreach_add_(H, reg)
            torch._foreach_div_(update, H)
            vars.update = update
            return vars

        # hessian
        raise NotImplementedError(structure)
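Of the structures listed, only the "diagonal" family is implemented above: it estimates a diagonal preconditioner from a single Hessian-vector product with an all-ones vector (that product equals the Hessian's row sums, and matches the diagonal exactly when the Hessian is diagonal). A standalone illustration of that Hessian-vector product with torch.autograd, independent of the torchzero module API (hvp_ones is a hypothetical helper name):

# Minimal illustration: Hessian-vector product with an all-ones vector, which
# the removed "diagonal" structure uses as a diagonal estimate (it is exact
# when the Hessian is diagonal, since H @ ones then equals diag(H)).
import torch

def hvp_ones(loss_fn, x: torch.Tensor) -> torch.Tensor:
    x = x.detach().requires_grad_(True)
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    v = torch.ones_like(x)
    (Hv,) = torch.autograd.grad(g, x, grad_outputs=v)
    return Hv

# separable quadratic: the Hessian is diag(2*a), so H @ ones equals the diagonal
a = torch.tensor([1.0, 2.0, 3.0])
x0 = torch.tensor([0.5, -1.0, 2.0])
print(hvp_ones(lambda x: (a * x * x).sum(), x0))   # tensor([2., 4., 6.])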
torchzero/modules/experimental/tropical_newton.py
DELETED

@@ -1,136 +0,0 @@
import warnings
from functools import partial
from typing import Literal
from collections.abc import Callable
import torch

from ...core import Chainable, apply, Module
from ...utils import vec_to_tensors, TensorList
from ...utils.derivatives import (
    hessian_list_to_mat,
    hessian_mat,
    jacobian_and_hessian_wrt,
)
from ..second_order.newton import lu_solve, cholesky_solve, least_squares_solve

def tropical_sum(x, dim): return torch.amax(x, dim=dim)
def tropical_mul(x, y): return x+y

def tropical_matmul(x: torch.Tensor, y: torch.Tensor):
    # this imlements matmul by calling mul and sum

    x_squeeze = False
    y_squeeze = False

    if x.ndim == 1:
        x_squeeze = True
        x = x.unsqueeze(0)

    if y.ndim == 1:
        y_squeeze = True
        y = y.unsqueeze(1)

    res = tropical_sum(tropical_mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim = -2)

    if x_squeeze: res = res.squeeze(-2)
    if y_squeeze: res = res.squeeze(-1)

    return res

def tropical_dot(x:torch.Tensor, y:torch.Tensor):
    assert x.ndim == 1 and y.ndim == 1
    return tropical_matmul(x.unsqueeze(0), y.unsqueeze(1))

def tropical_outer(x:torch.Tensor, y:torch.Tensor):
    assert x.ndim == 1 and y.ndim == 1
    return tropical_matmul(x.unsqueeze(1), y.unsqueeze(0))


def tropical_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    r = b.unsqueeze(1) - A
    return r.amin(dim=-2)

def tropical_solve_and_reconstruct(A: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    r = b.unsqueeze(1) - A
    x = r.amin(dim=-2)
    b_hat = tropical_matmul(A, x)
    return x, b_hat

def tikhonov(H: torch.Tensor, reg: float):
    if reg!=0: H += torch.eye(H.size(-1), dtype=H.dtype, device=H.device) * reg
    return H


class TropicalNewton(Module):
    """suston"""
    def __init__(
        self,
        reg: float | None = None,
        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
        vectorize: bool = True,
        interpolate:bool=False,
        inner: Chainable | None = None,
    ):
        defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, interpolate=interpolate)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, vars):
        params = TensorList(vars.params)
        closure = vars.closure
        if closure is None: raise RuntimeError('NewtonCG requires closure')

        settings = self.settings[params[0]]
        reg = settings['reg']
        hessian_method = settings['hessian_method']
        vectorize = settings['vectorize']
        interpolate = settings['interpolate']

        # ------------------------ calculate grad and hessian ------------------------ #
        if hessian_method == 'autograd':
            with torch.enable_grad():
                loss = vars.loss = vars.loss_approx = closure(False)
                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                g_list = [t[0] for t in g_list] # remove leading dim from loss
                vars.grad = g_list
                H = hessian_list_to_mat(H_list)

        elif hessian_method in ('func', 'autograd.functional'):
            strat = 'forward-mode' if vectorize else 'reverse-mode'
            with torch.enable_grad():
                g_list = vars.get_grad(retain_graph=True)
                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]

        else:
            raise ValueError(hessian_method)

        # -------------------------------- inner step -------------------------------- #
        if 'inner' in self.children:
            g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
        g = torch.cat([t.view(-1) for t in g_list])

        # ------------------------------- regulazition ------------------------------- #
        if reg is not None: H = tikhonov(H, reg)

        # ----------------------------------- solve ---------------------------------- #
        tropical_update, g_hat = tropical_solve_and_reconstruct(H, g)

        g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
        abs_error = torch.linalg.vector_norm(g-g_hat) # pylint:disable=not-callable
        rel_error = abs_error/g_norm.clip(min=1e-8)

        if interpolate:
            if rel_error > 1e-8:

                update = cholesky_solve(H, g)
                if update is None: update = lu_solve(H, g)
                if update is None: update = least_squares_solve(H, g)

                tropical_update.lerp_(update.ravel(), rel_error.clip(max=1))

        vars.update = vec_to_tensors(tropical_update, params)
        return vars
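The tropical helpers above work in the max-plus semiring: "multiplication" is ordinary addition, "summation" is max, and tropical_solve is the residuation x_j = min_i (b_i - A_ij), which yields the greatest subsolution of A ⊗ x = b. A small self-contained check of those two operations, written with plain torch to mirror the definitions above rather than import them:

# Standalone check of the max-plus ("tropical") operations used above:
# (A ⊗ B)_ij = max_k (A_ik + B_kj), and the solve is the residuation
# x_j = min_i (b_i - A_ij), which reconstructs b exactly when b lies in the
# tropical column span of A.
import torch

def tropical_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    return (A.unsqueeze(-1) + B.unsqueeze(-3)).amax(dim=-2)

def tropical_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return (b.unsqueeze(1) - A).amin(dim=-2)

A = torch.tensor([[0.0, 2.0], [1.0, 0.0]])
x = torch.tensor([1.0, 3.0])
b = tropical_matmul(A, x.unsqueeze(1)).squeeze(1)          # tensor([5., 3.])
x_hat = tropical_solve(A, b)                               # greatest subsolution
print(tropical_matmul(A, x_hat.unsqueeze(1)).squeeze(1))   # reconstructs b: tensor([5., 3.])

The rel_error check in TropicalNewton.step measures exactly this reconstruction gap and, when interpolate is enabled, blends the tropical solution with an ordinary linear solve in proportion to it.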
torchzero/modules/lr/__init__.py
DELETED
torchzero/modules/lr/lr.py
DELETED
@@ -1,59 +0,0 @@
import torch

from ...core import Transform
from ...utils import NumberList, TensorList, generic_eq


def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
    """multiplies by lr if lr is not 1"""
    if generic_eq(lr, 1): return tensors
    if inplace: return tensors.mul_(lr)
    return tensors * lr

class LR(Transform):
    def __init__(self, lr: float):
        defaults=dict(lr=lr)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        return lazy_lr(TensorList(tensors), lr=self.get_settings('lr', params=params), inplace=True)

class StepSize(Transform):
    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name"""
    def __init__(self, step_size: float, key = 'step_size'):
        defaults={"key": key, key: step_size}
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        lrs = []
        for p in params:
            settings = self.settings[p]
            lrs.append(settings[settings['key']])
        return lazy_lr(TensorList(tensors), lr=lrs, inplace=True)


def warmup(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
    """returns warm up lr scalar"""
    if step > steps: return end_lr
    return start_lr + (end_lr - start_lr) * (step / steps)

class Warmup(Transform):
    def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        start_lr, end_lr = self.get_settings('start_lr', 'end_lr', params=params, cls = NumberList)
        num_steps = self.settings[params[0]]['steps']
        step = self.global_state.get('step', 0)

        target = lazy_lr(
            TensorList(tensors),
            lr=warmup(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
            inplace=True
        )
        self.global_state['step'] = step + 1
        return target
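The removed Warmup transform scales the incoming update by a learning rate interpolated linearly from start_lr to end_lr over the first `steps` steps and held at end_lr afterwards. A standalone sketch of that schedule alone (plain Python, not the Transform API; warmup_lr is an illustrative name):

# Standalone sketch of the linear warmup schedule used by the removed Warmup
# transform: ramp from start_lr to end_lr over `steps` steps, then stay flat.
def warmup_lr(step: int, start_lr: float = 1e-5, end_lr: float = 1.0, steps: int = 100) -> float:
    if step > steps:
        return end_lr
    return start_lr + (end_lr - start_lr) * (step / steps)

print(warmup_lr(0))     # 1e-05
print(warmup_lr(50))    # ~0.5
print(warmup_lr(200))   # 1.0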
torchzero/modules/lr/step_size.py
DELETED

@@ -1,97 +0,0 @@
import random
from typing import Any

import torch

from ...core import Transform
from ...utils import TensorList, NumberList


class PolyakStepSize(Transform):
    """Polyak step-size.

    Args:
        max (float | None, optional): maximum possible step size. Defaults to None.
        min_obj_value (int, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
        use_grad (bool, optional):
            if True, uses dot product of update and gradient to compute the step size.
            Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
            Defaults to True.
        parameterwise (bool, optional):
            if True, calculate Polyak step-size for each parameter separately,
            if False calculate one global step size for all parameters. Defaults to False.
        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
    """
    def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):

        defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
        super().__init__(defaults, uses_grad=use_grad)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        loss = vars.get_loss(False)
        assert grads is not None
        tensors = TensorList(tensors)
        grads = TensorList(grads)
        alpha = self.get_settings('alpha', params=params, cls=NumberList)
        settings = self.settings[params[0]]
        parameterwise = settings['parameterwise']
        use_grad = settings['use_grad']
        max = settings['max']
        min_obj_value = settings['min_obj_value']

        if parameterwise:
            if use_grad: denom = (tensors * grads).sum()
            else: denom = tensors.pow(2).sum()
            polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
            polyak_step_size = polyak_step_size.where(denom != 0, 0)
            if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)

        else:
            if use_grad: denom = tensors.dot(grads)
            else: denom = tensors.dot(tensors)
            if denom == 0: polyak_step_size = 0 # we converged
            else: polyak_step_size = (loss - min_obj_value) / denom

            if max is not None:
                if polyak_step_size > max: polyak_step_size = max

        tensors.mul_(alpha * polyak_step_size)
        return tensors



class RandomStepSize(Transform):
    """Uses random global step size from `low` to `high`.

    Args:
        low (float, optional): minimum learning rate. Defaults to 0.
        high (float, optional): maximum learning rate. Defaults to 1.
        parameterwise (bool, optional):
            if True, generate random step size for each parameter separately,
            if False generate one global random step size. Defaults to False.
    """
    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        settings = self.settings[params[0]]
        parameterwise = settings['parameterwise']

        seed = settings['seed']
        if 'generator' not in self.global_state:
            self.global_state['generator'] = random.Random(seed)
        generator: random.Random = self.global_state['generator']

        if parameterwise:
            low, high = self.get_settings('low', 'high', params=params)
            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
        else:
            low = settings['low']
            high = settings['high']
            lr = generator.uniform(low, high)

        torch._foreach_mul_(tensors, lr)
        return tensors
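For reference, the global (non-parameterwise) branch of PolyakStepSize applies the classic Polyak rule t = (f(x) - f*) / <d, g>, where d is the update direction, g the gradient, and f* the estimated minimum. A self-contained sketch applying that rule to plain gradient descent (so d = g) on a toy quadratic; illustrative only, not the torchzero Transform:

# Standalone sketch of the global Polyak step size t = (f(x) - f_min) / <d, g>,
# applied to plain gradient descent (d = g) on a toy quadratic.
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * (x * x).sum()          # minimum value is 0 at x = 0

x = torch.tensor([3.0, -4.0], requires_grad=True)
f_min = 0.0
for _ in range(5):
    loss = f(x)
    (g,) = torch.autograd.grad(loss, x)
    denom = g.dot(g)                    # <d, g> with d = g
    t = 0.0 if denom == 0 else (loss.item() - f_min) / denom.item()
    with torch.no_grad():
        x -= t * g                      # loss halves each iteration on this quadratic
    print(round(loss.item(), 4), round(t, 4))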