torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/utils/compile.py
CHANGED
@@ -38,11 +38,11 @@ class _MaybeCompiledFunc:
 _optional_compiler = _OptionalCompiler()
 """this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""

-def
+def enable_compilation(enable: bool=True):
     """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
     _optional_compiler.enable = enable

-def
+def allow_compile(fn): return _optional_compiler.enable_compilation(fn)

 def benchmark_compile_cuda(fn, n: int, **kwargs):
     # warmup
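The renamed helpers above are thin toggles around the module-level `_OptionalCompiler`. A minimal usage sketch, assuming the functions are imported straight from `torchzero.utils.compile` and that `allow_compile` returns a callable wrapper as its one-liner suggests; the decorated function is a made-up example:

```python
from torchzero.utils.compile import allow_compile, enable_compilation

enable_compilation(True)   # opt in; off by default, and torch.compile may not work on e.g. Windows

@allow_compile             # compiled only because compilation was enabled above
def lerp_sq(a, b, t):      # hypothetical toy function
    return (a + (b - a) * t) ** 2
```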
torchzero/utils/derivatives.py
CHANGED
@@ -4,9 +4,10 @@ import torch
 import torch.autograd.forward_ad as fwAD

 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
+from .tensorlist import TensorList

 def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-    flat_outputs = torch.cat([i.
+    flat_outputs = torch.cat([i.ravel() for i in outputs])
     grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
     jac = []
     for i in range(flat_outputs.numel()):
@@ -23,7 +24,7 @@ def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], crea


 def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-    flat_outputs = torch.cat([i.
+    flat_outputs = torch.cat([i.ravel() for i in outputs])
     return torch.autograd.grad(
         flat_outputs,
         wrt,
@@ -39,10 +40,10 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:

     Args:
         jacs (Sequence[torch.Tensor]):
-            output from jacobian_wrt where ach tensor has the shape
+            output from jacobian_wrt where ach tensor has the shape ``(*output.shape, *wrt[i].shape)``.

     Returns:
-        torch.Tensor: has the shape
+        torch.Tensor: has the shape ``(output.ndim, wrt.ndim)``.
     """
     if not jacs:
         return torch.empty(0, 0)
@@ -261,7 +262,7 @@ def jvp_fd_central(
     params: Iterable[torch.Tensor],
     tangent: Iterable[torch.Tensor],
     h=1e-3,
-    normalize=
+    normalize=True,
 ) -> tuple[torch.Tensor | None, torch.Tensor]:
     """Jacobian vector product using central finite difference formula.

@@ -310,7 +311,7 @@ def jvp_fd_forward(
     tangent: Iterable[torch.Tensor],
     h=1e-3,
     v_0=None,
-    normalize=
+    normalize=True,
 ) -> tuple[torch.Tensor | None, torch.Tensor]:
     """Jacobian vector product using forward finite difference formula.
     Loss at initial point can be specified in the `v_0` argument.
@@ -357,52 +358,18 @@
     if normalize: res = res * tangent_norm
     return v_0, res

-def hvp(
-    params: Iterable[torch.Tensor],
-    grads: Iterable[torch.Tensor],
-    vec: Iterable[torch.Tensor],
-    retain_graph=None,
-    create_graph=False,
-    allow_unused=None,
-):
-    """Hessian-vector product
-
-    Example:
-    ```python
-    model = nn.Linear(4, 2)
-    X = torch.randn(10, 4)
-    y = torch.randn(10, 2)
-
-    y_hat = model(X)
-    loss = F.mse_loss(y_hat, y)
-    loss.backward(create_graph=True)
-
-    grads = [p.grad for p in model.parameters()]
-    vec = [torch.randn_like(p) for p in model.parameters()]
-
-    # list of tensors, same layout as model.parameters()
-    hvp(model.parameters(), grads, vec=vec)
-    ```
-    """
-    params = list(params)
-    g = list(grads)
-    vec = list(vec)
-
-    with torch.enable_grad():
-        return torch.autograd.grad(g, params, vec, create_graph=create_graph, retain_graph=retain_graph, allow_unused=allow_unused)
-

 @torch.no_grad
 def hvp_fd_central(
     closure,
     params: Iterable[torch.Tensor],
-
+    x: Iterable[torch.Tensor],
     h=1e-3,
-    normalize=
+    normalize=True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
-    """
+    """Returns ``(loss_approx, Hx)``.

-    Please note that this will clear
+    Please note that this will clear ``grad`` attributes in params.

     Example:
     ```python
@@ -424,48 +391,48 @@ def hvp_fd_central(
     ```
     """
     params = list(params)
-
+    x = list(x)

     vec_norm = None
     if normalize:
-        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in
+        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in x])) # pylint:disable=not-callable
         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
-
+        x = torch._foreach_div(x, vec_norm)

-
-    torch._foreach_add_(params,
+    xh = torch._foreach_mul(x, h)
+    torch._foreach_add_(params, xh)
     with torch.enable_grad(): loss = closure()
     g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

-    torch._foreach_sub_(params,
-    torch._foreach_sub_(params,
+    torch._foreach_sub_(params, xh)
+    torch._foreach_sub_(params, xh)
     with torch.enable_grad(): loss = closure()
     g_minus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

-    torch._foreach_add_(params,
+    torch._foreach_add_(params, xh)
     for p in params: p.grad = None

-
-    torch._foreach_sub_(
-    torch._foreach_div_(
+    hx = g_plus
+    torch._foreach_sub_(hx, g_minus)
+    torch._foreach_div_(hx, 2*h)

-    if normalize: torch._foreach_mul_(
-    return loss,
+    if normalize: torch._foreach_mul_(hx, vec_norm)
+    return loss, hx

 @torch.no_grad
 def hvp_fd_forward(
     closure,
     params: Iterable[torch.Tensor],
-
+    x: Iterable[torch.Tensor],
     h=1e-3,
     g_0=None,
-    normalize=
+    normalize=True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
-    """
+    """Returns ``(loss_approx, Hx)``.

-    Gradient at initial point can be specified in the
+    Gradient at initial point can be specified in the ``g_0`` argument.

-    Please note that this will clear
+    Please note that this will clear ``grad`` attributes in params.

     Example:
     ```python
@@ -492,16 +459,16 @@ def hvp_fd_forward(
     """

     params = list(params)
-
+    x = list(x)
     loss = None

     vec_norm = None
     if normalize:
-        vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in
+        vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in x])) # pylint:disable=not-callable
         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
-
+        x = torch._foreach_div(x, vec_norm)

-
+    xh = torch._foreach_mul(x, h)

     if g_0 is None:
         with torch.enable_grad(): loss = closure()
@@ -509,18 +476,75 @@
     else:
         g_0 = list(g_0)

-    torch._foreach_add_(params,
+    torch._foreach_add_(params, xh)
     with torch.enable_grad():
         l = closure()
         if loss is None: loss = l
     g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

-    torch._foreach_sub_(params,
+    torch._foreach_sub_(params, xh)
     for p in params: p.grad = None

-
-    torch._foreach_sub_(
-    torch._foreach_div_(
+    hx = g_plus
+    torch._foreach_sub_(hx, g_0)
+    torch._foreach_div_(hx, h)
+
+    if normalize: torch._foreach_mul_(hx, vec_norm)
+    return loss, hx
+
+@torch.no_grad
+def hessian_fd(fn, params: Sequence[torch.Tensor], eps: float = 1e-4, full: bool = True):
+    """returns ``f(x), g(x), H(x)``, where ``g(x)`` is a tensor list.
+
+    Number of evals for full is: 4n^2 - 2n + 1
+
+    Number of evals for upper is: 2n^2 + 1.
+    """
+    params = TensorList(params)
+    p_0 = params.clone()
+    n = sum(t.numel() for t in params)
+    device = params[0].device; dtype = params[0].dtype
+    fx = fn()
+    g = params.zeros_like()
+    H = torch.zeros((n, n), device=device, dtype=dtype)
+
+    for i in range(n):
+        for j in (range(n) if full else range(i, n)):
+            if i == j:
+                params.flat_set_lambda_(i, lambda x: x + eps)
+                f_plus = fn()
+
+                params.flat_set_lambda_(i, lambda x: x - 2 * eps)
+                f_minus = fn()
+
+                # params.flat_set_lambda_(i, lambda x: x + eps)
+                g.flat_set_(i, (f_plus - f_minus) / (2*eps))
+                H[i, i] = (f_plus - 2 * fx + f_minus) / (eps ** 2)
+
+            else:
+                params.flat_set_lambda_(i, lambda x: x + eps)
+                params.flat_set_lambda_(j, lambda x: x + eps)
+                f_pp = fn()
+
+                params.flat_set_lambda_(i, lambda x: x - 2 * eps)
+                f_np = fn()
+
+                params.flat_set_lambda_(j, lambda x: x - 2 * eps)
+                f_nn = fn()
+
+                params.flat_set_lambda_(i, lambda x: x + 2 * eps)
+                f_pn = fn()
+
+                # params.flat_set_lambda_(i, lambda x: x - eps)
+                # params.flat_set_lambda_(j, lambda x: x + eps)
+
+                H[i, j] = (f_pp - f_np - f_pn + f_nn) / (4 * eps ** 2)
+                if not full: H[j, i] = H[i, j]
+
+            params.copy_(p_0) # otherwise inaccuracy builds up
+
+    if full:
+        H = H + H.T
+        H /= 2

-
-    return loss, hvp_
+    return fx, g, H
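The reworked `hvp_fd_central` implements the standard central-difference Hessian-vector product, H·x ≈ (∇f(θ + h·x) − ∇f(θ − h·x)) / (2h), with the two gradients coming from the closure and combined via `_foreach` ops. A standalone sketch of the same formula on a toy quadratic, using plain autograd instead of torchzero's closure machinery (all names below are for illustration only):

```python
import torch

def hvp_fd_central_sketch(f, theta: torch.Tensor, x: torch.Tensor, h: float = 1e-3) -> torch.Tensor:
    """Central-difference HVP: (grad f(theta + h*x) - grad f(theta - h*x)) / (2h)."""
    def grad_at(point):
        point = point.detach().requires_grad_(True)
        return torch.autograd.grad(f(point), point)[0]
    return (grad_at(theta + h * x) - grad_at(theta - h * x)) / (2 * h)

A = torch.randn(5, 5); A = A @ A.T                 # symmetric, so the Hessian of f is exactly A
f = lambda p: 0.5 * p @ A @ p
theta, x = torch.randn(5), torch.randn(5)
print(torch.allclose(hvp_fd_central_sketch(f, theta, x), A @ x, atol=1e-2))
```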
torchzero/utils/optimizer.py
CHANGED
@@ -64,22 +64,15 @@ def get_group_vals(param_groups: Iterable[Mapping[str, Any]],
         values[i].extend(group_value for _ in range(num_params))
     return values

-
-Init = _InitLiterals | Any | list[_InitLiterals | Any] | tuple[_InitLiterals | Any]
+Init = Any

-def _make_initial_state_value(
-    if callable(init): return init(
+def _make_initial_state_value(tensor: torch.Tensor, init: Init, i: int | None):
+    if callable(init): return init(tensor)
     if isinstance(init, torch.Tensor): return init.detach().clone()

-    if isinstance(init, str):
-        if init in ('param','params'): return param.detach().clone()
-        if init in ('grad', 'grads'):
-            if param.grad is None: raise RuntimeError('init is set to "grad, but param.grad is None"')
-            return param.grad.detach().clone()
-
     if isinstance(init, (list,tuple)):
         if i is None: raise RuntimeError(f'init is per-parameter ({type(init)}) but parameter index i is None')
-        return _make_initial_state_value(
+        return _make_initial_state_value(tensor, init[i], None)

     return init

@@ -133,72 +126,6 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     return values


-class Optimizer(torch.optim.Optimizer, ABC):
-    """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
-
-    Args:
-        params (iterable): an iterable of :class:`torch.Tensor` s or
-            :class:`dict` s. Specifies what Tensors should be optimized.
-        defaults (dict | None): a dict containing default values of optimization
-            options (used when a parameter group doesn't specify them).
-    """
-    def __init__(self, params, defaults: dict[str, Any] | None = None, **_defaults):
-        if defaults is None: defaults = {}
-        defaults.update(_defaults)
-
-        super().__init__(params, defaults)
-        self.global_state = self.state[self.param_groups[0]['params'][0]]
-        """state of 1st parameter, can be used as global state which is how L-BFGS uses it in pytorch, and there is some kind of good reason to do it like that"""
-
-    def get_params(self, mode: ParamFilter = 'requires_grad', cls: type[ListLike] = TensorList) -> ListLike:
-        return get_params(self.param_groups, mode, cls)
-
-    @overload
-    def group_vals(self, key: str, *,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike: ...
-    @overload
-    def group_vals(self, key: list[str] | tuple[str,...], *,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
-    @overload
-    def group_vals(self, key: str, key2: str, *keys: str,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
-
-    def group_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike | list[ListLike]:
-        return get_group_vals(self.param_groups, key, key2, *keys, mode = mode, cls = cls) # pyright:ignore[reportArgumentType]
-
-
-    @overload
-    def state_vals(self, key: str, *,
-                   init: Init = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> ListLike: ...
-    @overload
-    def state_vals(self, key: list[str] | tuple[str,...], *,
-                   init: Init | Sequence[Init] = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> list[ListLike]: ...
-    @overload
-    def state_vals(self, key: str, key2: str, *keys: str,
-                   init: Init | Sequence[Init] = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> list[ListLike]: ...
-
-    def state_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-                   init: Init | Sequence[Init] = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> ListLike | list[ListLike]:
-
-        if isinstance(mode, (list,tuple)): params = mode
-        else: params = self.get_params(mode)
-
-        return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
-
-
-    # shut up pylance
-    @abstractmethod
-    def step(self, closure) -> Any: ... # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-
 def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
     if set_to_none:
         for p in params:
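The simplified `_make_initial_state_value` above resolves an `init` spec that may be a callable, a tensor, a per-parameter list/tuple, or a plain value; the old string forms (`'param'`, `'grad'`) are gone. A standalone sketch of that resolution order (not torchzero's function, just the same idea under a hypothetical name):

```python
import torch

def resolve_init(tensor: torch.Tensor, init, i: int | None = None):
    if callable(init): return init(tensor)                          # e.g. torch.zeros_like
    if isinstance(init, torch.Tensor): return init.detach().clone() # detached copy of a given tensor
    if isinstance(init, (list, tuple)):                             # per-parameter specs, pick entry i
        return resolve_init(tensor, init[i], None)
    return init                                                     # plain value, used as-is

p = torch.randn(3)
print(resolve_init(p, torch.zeros_like))              # zeros shaped like p
print(resolve_init(p, 0.9))                           # 0.9
print(resolve_init(p, [torch.ones_like, 0.5], i=0))   # ones shaped like p
```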
torchzero/utils/python_tools.py
CHANGED
@@ -1,3 +1,4 @@
+import importlib
 import functools
 import operator
 from typing import Any, TypeVar, overload
@@ -40,6 +41,11 @@ def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable
         return any(i!=y for i in x)
     return any(i!=j for i,j in zip(x,y))

+def generic_is_none(x: Any | Iterable[Any]):
+    """returns True if x is None or iterable with all elements set to None"""
+    if x is None: return True
+    if isinstance(x, Iterable): return all(i is None for i in x)
+    return False

 def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
@@ -68,3 +74,28 @@ def safe_dict_update_(d1_:dict, d2:dict):
     if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
     d1_.update(d2)

+# lazy loader from https://stackoverflow.com/a/78312674/15673832
+class LazyLoader:
+    'thin shell class to wrap modules. load real module on first access and pass thru'
+
+    def __init__(self, modname):
+        self._modname = modname
+        self._mod = None
+
+    def __getattr__(self, attr):
+        'import module on first attribute access'
+
+        try:
+            return getattr(self._mod, attr)
+
+        except Exception as e :
+            if self._mod is None :
+                # module is unset, load it
+                self._mod = importlib.import_module (self._modname)
+            else :
+                # module is set, got different exception from getattr (). reraise it
+                raise e
+
+        # retry getattr if module was just loaded for first time
+        # call this outside exception handler in case it raises new exception
+        return getattr (self._mod, attr)
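A short usage sketch for the new `LazyLoader`: the wrapped module is imported only on first attribute access, so optional backends stay uninstalled until they are actually used. `scipy.optimize` is just an example target here, not something this diff itself wires up:

```python
from torchzero.utils.python_tools import LazyLoader

lazy_opt = LazyLoader("scipy.optimize")   # nothing imported yet

def minimize_rosenbrock():
    # first attribute access triggers importlib.import_module("scipy.optimize")
    rosen = lambda x: (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2
    return lazy_opt.minimize(rosen, x0=[0.0, 0.0])
```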
torchzero/utils/tensorlist.py
CHANGED
@@ -22,7 +22,6 @@ from typing_extensions import Self, TypeAlias, Unpack

 from .metrics import Metrics, evaluate_metric, calculate_metric_list
 from .numberlist import NumberList, as_numberlist, maybe_numberlist
-from .ops import where_
 from .python_tools import generic_ne, zipmap

 _Scalar = int | float | bool | complex
@@ -346,6 +345,10 @@ class TensorList(list[torch.Tensor | Any]):
     def global_all(self): return builtins.all(self.all())
     def global_numel(self) -> int: return builtins.sum(self.numel())

+    def global_allclose(self, other: _TensorSeq, rtol: float = 0.00001, atol: float = 1e-8, equal_nan: bool = False) -> bool:
+        bools = self.zipmap_args(torch.allclose, other, rtol, atol, equal_nan)
+        return all(bools)
+
     def empty_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.empty_like(i, **kwargs) for i in self)
     def zeros_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.zeros_like(i, **kwargs) for i in self)
     def ones_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.ones_like(i, **kwargs) for i in self)
@@ -509,7 +512,6 @@ class TensorList(list[torch.Tensor | Any]):
         torch._foreach_mul_(self, other)
         return self

-    # TODO: benchmark
     def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
         if generic_ne(other, 1):
             return self * other
@@ -536,6 +538,13 @@
         torch._foreach_pow_(self, exponent)
         return self

+    def lazy_pow(self, other: int | float | list[int | float] | tuple[int | float]):
+        if generic_ne(other, 1): return self.pow(other)
+        return self
+    def lazy_pow_(self, other: int | float | list[int | float] | tuple[int | float]):
+        if generic_ne(other, 1): return self.pow_(other)
+        return self
+
     def rpow(self, input: _Scalar | _TensorSeq): return self.__class__(torch._foreach_pow(input, self))
     def rpow_(self, input: _TensorSeq):
         torch._foreach_pow_(input, self)
@@ -984,9 +993,6 @@
     def where(self, condition: "torch.Tensor | _TensorSeq", other: _STOrSTSeq):
         """self where condition is true other otherwise"""
         return self.zipmap_args(_MethodCallerWithArgs('where'), condition, other)
-    def where_(self, condition: "torch.Tensor | _TensorSeq", other: "torch.Tensor | _TensorSeq"):
-        """self where condition is true other otherwise"""
-        return self.zipmap_args_inplace_(where_, condition, other)

     def masked_fill(self, mask: "torch.Tensor | _TensorSeq", fill_value: "_Scalar | _ScalarSeq"):
         """Same as tensor[mask] = value (not in-place), where value must be scalar/scalars"""
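The new `lazy_pow`/`lazy_pow_` follow the same pattern as the existing `lazy_mul`: `generic_ne(other, 1)` guards the `_foreach` call so that an exponent of 1 (or a list of all-ones) is a no-op. A minimal sketch of that pattern outside the `TensorList` class, with hypothetical `_sketch` names:

```python
import torch

def generic_ne_sketch(x, y) -> bool:
    # True if x (a scalar or an iterable of scalars) differs from y anywhere
    if isinstance(x, (list, tuple)): return any(i != y for i in x)
    return x != y

def lazy_pow_sketch(tensors: list[torch.Tensor], exponent):
    if generic_ne_sketch(exponent, 1):
        return list(torch._foreach_pow(tensors, exponent))  # fused foreach kernel
    return tensors                                          # exponent is 1 everywhere: skip the work

ts = [torch.randn(3), torch.randn(2, 2)]
assert lazy_pow_sketch(ts, 1) is ts      # skipped, same list object returned
squared = lazy_pow_sketch(ts, 2)         # actually computed
```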
torchzero/utils/thoad_tools.py
ADDED
@@ -0,0 +1,68 @@
+import itertools
+from collections.abc import Callable
+from importlib.util import find_spec
+from typing import TYPE_CHECKING, cast
+
+import torch
+
+from .python_tools import LazyLoader
+
+lazy_thoad = LazyLoader("thoad")
+if TYPE_CHECKING:
+    import thoad
+    lazy_thoad = cast(thoad, lazy_thoad)
+
+def thoad_single_tensor(
+    ctrl: "thoad.Controller",
+    params: list[torch.Tensor],
+    order: int
+) -> torch.Tensor:
+    """treats params as if they were concatenated into a vector."""
+
+    if not all(p.requires_grad for p in params):
+        raise ValueError("All parameters must have requires_grad=True")
+
+    if order < 1:
+        raise ValueError("Order must be at least 1")
+
+    # we need parameter sizes and total size N
+    # final tensor is (N, N, ..., N) with `order` dimensions.
+    param_numels = [p.numel() for p in params]
+    total_params = sum(param_numels)
+
+    final_shape = (total_params,) * order
+    p = params[0]
+    T = torch.zeros(final_shape, device=p.device, dtype=p.dtype)
+
+    # start/end indices for each parameter in the flattened vector.
+    offsets = torch.cumsum(torch.tensor([0] + param_numels), dim=0)
+
+    # for order=2 this iterates through (p0,p0), (p0,p1), (p1,p0), (p1,p1), etc.
+    param_indices = range(len(params))
+    for block_indices in itertools.product(param_indices, repeat=order):
+
+        block_params = tuple(params[i] for i in block_indices)
+        block_tensor, _ = ctrl.fetch_hgrad(variables=block_params) # (1, *p1.shape, *p2.shape, ...).
+        block_tensor = block_tensor.squeeze(0) # (*p1.shape, *p2.shape, ...)
+
+        # convert (*p1.shape, *p2.shape) to (p1.numel(), p2.numel())
+        block_flat_shape = tuple(param_numels[i] for i in block_indices)
+        block_tensor_flat = block_tensor.reshape(block_flat_shape)
+
+        # place the flattened block into T
+        slicing = tuple(
+            slice(offsets[i], offsets[i+1]) for i in block_indices
+        )
+        T[slicing] = block_tensor_flat
+
+    ctrl.clear()
+    return T
+
+def thoad_derivatives(
+    ctrl: "thoad.Controller",
+    params: list[torch.Tensor],
+    order: int,
+):
+    """returns all derivatives up to ``order`` in ascending order, all as single tensors
+    as if parameters were concatenated to a vector"""
+    return [thoad_single_tensor(ctrl, params, o) for o in range(1, order+1)]
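The block-assembly loop in `thoad_single_tensor` is independent of thoad itself: each derivative block of shape `(*p_i.shape, ..., *p_k.shape)` is flattened and written into the `(N, ..., N)` result at offsets given by cumulative parameter sizes. A sketch of that placement for `order=2`, with random dummy blocks standing in for `ctrl.fetch_hgrad`:

```python
import itertools
import torch

params = [torch.randn(2, 3), torch.randn(4)]     # numels 6 and 4, so N = 10
numels = [p.numel() for p in params]
offsets = [0]
for n in numels: offsets.append(offsets[-1] + n) # cumulative start indices: [0, 6, 10]
N = sum(numels)

H = torch.zeros(N, N)
for i, j in itertools.product(range(len(params)), repeat=2):
    # stand-in for ctrl.fetch_hgrad(...): a block shaped (*params[i].shape, *params[j].shape)
    block = torch.randn(*params[i].shape, *params[j].shape)
    H[offsets[i]:offsets[i + 1], offsets[j]:offsets[j + 1]] = block.reshape(numels[i], numels[j])

print(H.shape)  # torch.Size([10, 10])
```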