torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/utils/linalg/solve.py
CHANGED

```diff
@@ -1,12 +1,41 @@
+# pyright: reportArgumentType=false
 from collections.abc import Callable
-from typing import overload
+from typing import Any, overload
+
 import torch
 
-from .. import
+from .. import (
+    TensorList,
+    generic_eq,
+    generic_finfo_eps,
+    generic_numel,
+    generic_randn_like,
+    generic_vector_norm,
+    generic_zeros_like,
+)
+
+
+def _make_A_mm_reg(A_mm: Callable | torch.Tensor, reg):
+    if callable(A_mm):
+        def A_mm_reg(x): # A_mm with regularization
+            Ax = A_mm(x)
+            if not generic_eq(reg, 0): Ax += x*reg
+            return Ax
+        return A_mm_reg
+
+    if not isinstance(A_mm, torch.Tensor): raise TypeError(type(A_mm))
+
+    def Ax_reg(x): # A_mm with regularization
+        if A_mm.ndim == 1: Ax = A_mm * x
+        else: Ax = A_mm @ x
+        if reg != 0: Ax += x*reg
+        return Ax
+    return Ax_reg
+
 
 @overload
 def cg(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
     b: torch.Tensor,
     x0_: torch.Tensor | None = None,
     tol: float | None = 1e-4,
@@ -24,17 +53,17 @@ def cg(
 ) -> TensorList: ...
 
 def cg(
-    A_mm: Callable,
+    A_mm: Callable | torch.Tensor,
     b: torch.Tensor | TensorList,
     x0_: torch.Tensor | TensorList | None = None,
     tol: float | None = 1e-4,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
 ):
-
-
-
-
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+
+    if tol is None: tol = eps
 
     if maxiter is None: maxiter = generic_numel(b)
     if x0_ is None: x0_ = generic_zeros_like(b)
@@ -44,9 +73,10 @@ def cg(
     p = residual.clone() # search direction
     r_norm = generic_vector_norm(residual)
     init_norm = r_norm
-    if
+    if r_norm < tol: return x
    k = 0
 
+
     while True:
         Ap = A_mm_reg(p)
         step_size = (r_norm**2) / p.dot(Ap)
@@ -55,7 +85,7 @@ def cg(
         new_r_norm = generic_vector_norm(residual)
 
         k += 1
-        if
+        if new_r_norm <= tol * init_norm: return x
         if k >= maxiter: return x
 
         beta = (new_r_norm**2) / (r_norm**2)
@@ -131,6 +161,8 @@ def nystrom_pcg(
         generator=generator,
     )
     lambd += reg
+    eps = torch.finfo(b.dtype).eps ** 2
+    if tol is None: tol = eps
 
     def A_mm_reg(x): # A_mm with regularization
         Ax = A_mm(x)
@@ -150,7 +182,7 @@ def nystrom_pcg(
     p = z.clone() # search direction
 
     init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
-    if
+    if init_norm < tol: return x
     k = 0
     while True:
         Ap = A_mm_reg(p)
@@ -160,10 +192,217 @@ def nystrom_pcg(
         residual -= step_size * Ap
 
         k += 1
-        if
+        if torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
         if k >= maxiter: return x
 
         z = P_inv @ residual
         beta = residual.dot(z) / rz
         p = z + p*beta
 
+
+def _safe_clip(x: torch.Tensor):
+    """makes sure scalar tensor x is not smaller than epsilon"""
+    assert x.numel() == 1, x.shape
+    eps = torch.finfo(x.dtype).eps
+    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+    return x
+
+def _trust_tau(x,d,trust_region):
+    xx = x.dot(x)
+    xd = x.dot(d)
+    dd = _safe_clip(d.dot(d))
+
+    rad = (xd**2 - dd * (xx - trust_region**2)).clip(min=0).sqrt()
+    tau = (-xd + rad) / dd
+
+    return x + tau * d
+
+
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    trust_region: float,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+) -> torch.Tensor: ...
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    trust_region: float,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+) -> TensorList: ...
+def steihaug_toint_cg(
+    A_mm: Callable | torch.Tensor,
+    b: torch.Tensor | TensorList,
+    trust_region: float,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+):
+    """
+    Solution is bounded to have L2 norm no larger than :code:`trust_region`. If solution exceeds :code:`trust_region`, CG is terminated early, so it is also faster.
+    """
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+
+    x = x0
+    if x is None: x = generic_zeros_like(b)
+    r = b
+    d = r.clone()
+
+    eps = generic_finfo_eps(b)**2
+    if tol is None: tol = eps
+
+    if generic_vector_norm(r) < tol:
+        return x
+
+    if maxiter is None:
+        maxiter = generic_numel(b)
+
+    for _ in range(maxiter):
+        Ad = A_mm_reg(d)
+
+        d_Ad = d.dot(Ad)
+        if d_Ad <= eps:
+            return _trust_tau(x, d, trust_region)
+
+        alpha = r.dot(r) / d_Ad
+        p_next = x + alpha * d
+
+        # check if the step exceeds the trust-region boundary
+        if generic_vector_norm(p_next) >= trust_region:
+            return _trust_tau(x, d, trust_region)
+
+        # update step, residual and direction
+        x = p_next
+        r_next = r - alpha * Ad
+
+        if generic_vector_norm(r_next) < tol:
+            return x
+
+        beta = r_next.dot(r_next) / r.dot(r)
+        d = r_next + beta * d
+        r = r_next
+
+    return x
+
+
+
+# Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
+@overload
+def minres(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> torch.Tensor: ...
+@overload
+def minres(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> TensorList: ...
+def minres(
+    A_mm,
+    b,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+):
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+    if tol is None: tol = eps**2
+
+    if maxiter is None: maxiter = generic_numel(b)
+    if x0 is None:
+        R = b
+        x0 = generic_zeros_like(b)
+    else:
+        R = b - A_mm_reg(x0)
+
+    X: Any = x0
+    beta = b_norm = generic_vector_norm(b)
+    if b_norm < eps**2:
+        return generic_zeros_like(b)
+
+
+    V = b / beta
+    V_prev = generic_zeros_like(b)
+    D = generic_zeros_like(b)
+    D_prev = generic_zeros_like(b)
+
+    c = -1
+    phi = tau = beta
+    s = delta1 = e = 0
+
+
+    for _ in range(maxiter):
+
+        P = A_mm_reg(V)
+        alpha = V.dot(P)
+        P -= beta*V_prev
+        P -= alpha*V
+        beta = generic_vector_norm(P)
+
+        delta2 = c*delta1 + s*alpha
+        gamma1 = s*delta1 - c*alpha
+        e_next = s*beta
+        delta1 = -c*beta
+
+        cgamma1 = c*gamma1
+        if trust_region is not None and cgamma1 >= 0:
+            if npc_terminate: return _trust_tau(X, R, trust_region)
+            return _trust_tau(X, D, trust_region)
+
+        if npc_terminate and cgamma1 >= 0:
+            return R
+
+        gamma2 = (gamma1**2 + beta**2)**(1/2)
+
+        if abs(gamma2) <= eps: # singular system
+            # c=0; s=1; tau=0
+            if trust_region is None: return X
+            return _trust_tau(X, D, trust_region)
+
+        c = gamma1 / gamma2
+        s = beta/gamma2
+        tau = c*phi
+        phi = s*phi
+
+        D_prev = D
+        D = (V - delta2*D - e*D_prev) / gamma2
+        e = e_next
+        X = X + tau*D
+
+        if trust_region is not None:
+            if generic_vector_norm(X) > trust_region:
+                return _trust_tau(X, D, trust_region)
+
+        if (abs(beta) < eps) or (phi / b_norm <= tol):
+            # R = zeros(R)
+            return X
+
+        V_prev = V
+        V = P/beta
+        R = s**2*R - phi*c*V
+
+    return X
```
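The new `_trust_tau` helper is the standard trust-region boundary step: it solves ‖x + τd‖ = Δ for the nonnegative root τ = (−xᵀd + √((xᵀd)² − dᵀd·(xᵀx − Δ²)))/dᵀd and returns x + τd; `_safe_clip` only guards the dᵀd denominator. A minimal, self-contained sketch of how `steihaug_toint_cg` uses this step, with plain tensors and illustrative names like `boundary_step` (not torchzero's API):

```python
import torch

def boundary_step(x: torch.Tensor, d: torch.Tensor, radius: float) -> torch.Tensor:
    # Solve ||x + tau*d|| = radius for tau >= 0: the quadratic
    # (d.d) tau^2 + 2 (x.d) tau + (x.x - radius^2) = 0, positive root.
    xx, xd, dd = x.dot(x), x.dot(d), d.dot(d)
    rad = (xd**2 - dd * (xx - radius**2)).clamp(min=0).sqrt()
    return x + ((-xd + rad) / dd) * d

torch.manual_seed(0)
M = torch.randn(5, 5)
A = M @ M.T + 5 * torch.eye(5)  # small SPD system
b = torch.randn(5)

# plain CG, truncated at the trust-region boundary (the Steihaug-Toint idea)
radius = 0.1
x = torch.zeros(5)
r = b.clone()
d = r.clone()
for _ in range(5):
    Ad = A @ d
    alpha = r.dot(r) / d.dot(Ad)
    x_next = x + alpha * d
    if x_next.norm() >= radius:          # step would leave the trust region:
        x = boundary_step(x, d, radius)  # walk along d only up to the boundary
        break
    r_next = r - alpha * Ad
    beta = r_next.dot(r_next) / r.dot(r)
    x, r, d = x_next, r_next, r_next + beta * d

print(x.norm())  # equals radius when the boundary was hit
```

Because the iterate is returned as soon as it crosses the boundary (or negative curvature is detected via `d_Ad <= eps`), the solver can stop well before CG converges, which is what the `steihaug_toint_cg` docstring means by "also faster".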
torchzero/utils/numberlist.py
CHANGED

```diff
@@ -129,4 +129,6 @@ class NumberList(list[int | float | Any]):
         return self.__class__(fn(i, *args, **kwargs) for i in self)
 
     def clamp(self, min=None, max=None):
+        return self.zipmap_args(_clamp, min, max)
+    def clip(self, min=None, max=None):
         return self.zipmap_args(_clamp, min, max)
```
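`NumberList` now answers to both `clamp` and `clip`, matching the `torch.Tensor.clamp`/`Tensor.clip` alias pair. A tiny sketch of the same alias pattern on a hypothetical list subclass (the real methods both delegate to `zipmap_args(_clamp, min, max)`):

```python
# Hypothetical Numbers class, not torchzero's NumberList: one clamping
# implementation exposed under both torch-style names.
class Numbers(list):
    def clamp(self, lo=None, hi=None):
        lo = -float("inf") if lo is None else lo
        hi = float("inf") if hi is None else hi
        return Numbers(min(max(v, lo), hi) for v in self)
    clip = clamp  # same behavior under both names

print(Numbers([1, 5, 9]).clip(2, 8))  # [2, 5, 8]
```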
torchzero/utils/optimizer.py
CHANGED

```diff
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Mapping, MutableSequence, Sequence, MutableMapping
 from typing import Any, Literal, TypeVar, overload
 
@@ -132,65 +133,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     return values
 
 
-
-def loss_at_params(closure, params: Iterable[torch.Tensor],
-                   new_params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-    params = TensorList(params)
-
-    old_params = params.clone() if restore else None
-
-    if isinstance(new_params, Sequence) and isinstance(new_params[0], torch.Tensor):
-        # when not restoring, copy new_params to params to avoid unexpected bugs due to shared storage
-        # when restoring params will be set back to old_params so its fine
-        if restore: params.set_(new_params)
-        else: params.copy_(new_params) # type:ignore
-
-    else:
-        new_params = totensor(new_params)
-        params.from_vec_(new_params)
-
-    if backward: loss = closure()
-    else: loss = closure(False)
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return tofloat(loss)
-
-def loss_grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-    params = TensorList(params)
-    old_params = params.clone() if restore else None
-    loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-    grad = params.ensure_grad_().grad
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return loss, grad
-
-def grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-    return loss_grad_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
-def loss_grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-    params = TensorList(params)
-    old_params = params.clone() if restore else None
-    loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-    grad = params.ensure_grad_().grad.to_vec()
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return loss, grad
-
-def grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-    return loss_grad_vec_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
-
-
-class Optimizer(torch.optim.Optimizer):
+class Optimizer(torch.optim.Optimizer, ABC):
     """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
 
     Args:
@@ -251,21 +194,10 @@ class Optimizer(torch.optim.Optimizer):
 
         return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
 
-    def loss_at_params(self, closure, params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-        return loss_at_params(closure=closure,params=self.get_params(),new_params=params,backward=backward,restore=restore)
-
-    def loss_grad_at_params(self, closure, params: Sequence[torch.Tensor] | Any, restore=False):
-        return loss_grad_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-    def grad_at_params(self, closure, new_params: Sequence[torch.Tensor], restore=False):
-        return self.loss_grad_at_params(closure=closure,params=new_params,restore=restore)[1]
-
-    def loss_grad_vec_at_params(self, closure, params: Any, restore=False):
-        return loss_grad_vec_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-    def grad_vec_at_params(self, closure, params: Any, restore=False):
-        return self.loss_grad_vec_at_params(closure=closure,params=params,restore=restore)[1]
 
+    # shut up pylance
+    @abstractmethod
+    def step(self, closure) -> Any: ... # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
 
 def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
     if set_to_none:
@@ -281,4 +213,53 @@ def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
     else:
         grad.requires_grad_(False)
 
-    torch._foreach_zero_(grads)
+    torch._foreach_zero_(grads)
+
+
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str, *,
+                  must_exist: bool = False, init: Init = torch.zeros_like,
+                  cls: type[ListLike] = list) -> ListLike: ...
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: list[str] | tuple[str,...], *,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> list[ListLike]: ...
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str, key2: str, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> list[ListLike]: ...
+
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+    # single key, return single cls
+    if isinstance(key, str) and key2 is None:
+        values = cls()
+        for i,s in enumerate(states):
+            if key not in s:
+                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                s[key] = _make_initial_state_value(tensors[i], init, i)
+            values.append(s[key])
+        return values
+
+    # multiple keys
+    k1 = (key,) if isinstance(key, str) else tuple(key)
+    k2 = () if key2 is None else (key2,)
+    keys = k1 + k2 + keys
+
+    values = [cls() for _ in keys]
+    for i,s in enumerate(states):
+        for k_i, key in enumerate(keys):
+            if key not in s:
+                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                k_init = init[k_i] if isinstance(init, (list,tuple)) else init
+                s[key] = _make_initial_state_value(tensors[i], k_init, i)
+            values[k_i].append(s[key])
+
+    return values
+
```
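`unpack_states` gathers one state buffer per parameter out of a sequence of per-tensor state dicts, lazily creating missing entries from `init`; with several keys it returns one list per key. A self-contained sketch of the single-key path, with a plain callable standing in for torchzero's `Init`/`_make_initial_state_value` machinery:

```python
import torch

# Sketch of the single-key unpack_states path above (simplified: init is just
# a callable applied to the matching tensor, no Init/ListLike typing).
def unpack_states_sketch(states, tensors, key, init=torch.zeros_like, must_exist=False):
    values = []
    for s, t in zip(states, tensors):
        if key not in s:
            if must_exist:
                raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
            s[key] = init(t)  # create the buffer on first use
        values.append(s[key])
    return values

params = [torch.randn(3), torch.randn(2, 2)]
states = [{} for _ in params]                 # one state dict per parameter
exp_avg = unpack_states_sketch(states, params, "exp_avg")
print([tuple(v.shape) for v in exp_avg])      # [(3,), (2, 2)]: buffers match parameters
print("exp_avg" in states[0])                 # True: stored in place, reused on the next step
```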
torchzero/utils/python_tools.py
CHANGED

```diff
@@ -1,7 +1,7 @@
 import functools
 import operator
-from typing import Any, TypeVar
-from collections.abc import Iterable, Callable
+from typing import Any, TypeVar, overload
+from collections.abc import Iterable, Callable, Mapping, MutableSequence
 from collections import UserDict
 
 
@@ -17,8 +17,8 @@ def flatten(iterable: Iterable) -> list[Any]:
     raise TypeError(f'passed object is not an iterable, {type(iterable) = }')
 
 X = TypeVar("X")
-# def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]:
-def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]:
+# def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]:
+def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]:
     """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
     return functools.reduce(operator.iconcat, x, [])
 
@@ -31,6 +31,16 @@ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable
         return all(i==y for i in x)
     return all(i==j for i,j in zip(x,y))
 
+def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
+    """generic not equals function that supports scalars and lists of numbers. Faster than not generic_eq"""
+    if isinstance(x, (int,float)):
+        if isinstance(y, (int,float)): return x!=y
+        return any(i!=x for i in y)
+    if isinstance(y, (int,float)):
+        return any(i!=y for i in x)
+    return any(i!=j for i,j in zip(x,y))
+
+
 def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
     Otherwise applies `fn` to this sequence and `other`.
@@ -38,3 +48,16 @@ def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     if isinstance(other, (list, tuple)): return self.__class__(fn(i, j, *args, **kwargs) for i, j in zip(self, other))
     return self.__class__(fn(i, other, *args, **kwargs) for i in self)
 
+ListLike = TypeVar('ListLike', bound=MutableSequence)
+@overload
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, *, cls:type[ListLike]=list) -> ListLike: ...
+@overload
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str, *keys:str, cls:type[ListLike]=list) -> list[ListLike]: ...
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None = None, *keys:str, cls:type[ListLike]=list) -> ListLike | list[ListLike]:
+    k1 = (key,) if isinstance(key, str) else tuple(key)
+    k2 = () if key2 is None else (key2,)
+    keys = k1 + k2 + keys
+
+    values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
+    if len(values) == 1: return values[0]
+    return values
```