torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/misc/escape.py
CHANGED
@@ -1,7 +1,10 @@
+import math
+
+from typing import Literal
 import torch
 
-from ...core import Module
-from ...utils import
+from ...core import Modular, Module, Var, Chainable
+from ...utils import NumberList, TensorList
 
 
 class EscapeAnnealing(Module):
@@ -42,19 +45,18 @@ class EscapeAnnealing(Module):
         if n_bad >= n_tol:
             for i in range(1, max_iter+1):
                 alpha = max_region * (i / max_iter)
-                pert = params.
+                pert = params.sphere_like(radius=alpha)
 
                 params.add_(pert)
                 f_star = closure(False)
 
-                if f_star < f_0-1e-
+                if math.isfinite(f_star) and f_star < f_0-1e-12:
                     var.update = None
                     var.stop = True
                     var.skip_update = True
                     return var
 
-
-                params.sub_(pert)
+                params.sub_(pert)
 
         self.global_state['n_bad'] = 0
-        return var
+        return var
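The new loop perturbs the parameters on spheres of growing radius and keeps a perturbation only when the loss finitely and strictly improves. A minimal standalone sketch of that logic with plain tensors (illustration only, using a hypothetical `try_escape` helper rather than the module's API):

```python
# Minimal sketch of the annealed-escape loop shown above
# (illustration only, not torchzero's EscapeAnnealing implementation).
import torch

def try_escape(params, closure, f_0, max_region=1.0, max_iter=10):
    """Perturb params on spheres of growing radius; keep the first perturbation that lowers the loss."""
    for i in range(1, max_iter + 1):
        alpha = max_region * (i / max_iter)
        # random direction rescaled to radius alpha (analogue of params.sphere_like(radius=alpha))
        pert = [torch.randn_like(p) for p in params]
        norm = torch.sqrt(sum(q.pow(2).sum() for q in pert))
        pert = [alpha * q / norm for q in pert]

        with torch.no_grad():
            for p, q in zip(params, pert):
                p.add_(q)
            f_star = closure()
            if torch.isfinite(torch.as_tensor(f_star)) and f_star < f_0 - 1e-12:
                return True   # keep the perturbed parameters
            for p, q in zip(params, pert):
                p.sub_(q)     # revert and try a larger radius
    return False
```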
torchzero/modules/misc/gradient_accumulation.py
CHANGED

@@ -3,46 +3,112 @@ import torch
 from ...core import Chainable, Module
 
 
+# class GradientAccumulation(Module):
+#     """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
+
+#     Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
+#     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+#     .. note::
+#         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+#     Args:
+#         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+#         n (int): number of gradients to accumulate.
+#         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+#         stop (bool, optional):
+#             this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
+
+#     Examples:
+#         Adam with gradients accumulated for 16 batches.
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.GradientAccumulation(
+#                     [tz.m.Adam(), tz.m.LR(1e-2)],
+#                     n=16
+#                 )
+#             )
+
+#     """
+#     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+#         defaults = dict(n=n, mean=mean, stop=stop)
+#         super().__init__(defaults)
+#         self.set_child('modules', modules)
+
+
+#     @torch.no_grad
+#     def step(self, var):
+#         accumulator = self.get_state(var.params, 'accumulator')
+#         settings = self.defaults
+#         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+#         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+#         # add update to accumulator
+#         torch._foreach_add_(accumulator, var.get_update())
+
+#         # step with accumulated updates
+#         if step % n == 0:
+#             if mean:
+#                 torch._foreach_div_(accumulator, n)
+
+#             var.update = [a.clone() for a in accumulator]
+#             var = self.children['modules'].step(var)
+
+#             # zero accumulator
+#             torch._foreach_zero_(accumulator)
+
+#         else:
+#             # prevent update
+#             if stop:
+#                 var.update = None
+#                 var.stop=True
+#                 var.skip_update=True
+
+#         return var
+
+
+
+
 class GradientAccumulation(Module):
-    """Uses
+    """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
 
-    Accumulating gradients for
+    Accumulating gradients for ``n`` steps is equivalent to increasing batch size by ``n``. Increasing the batch size
     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
 
-
+    Note:
         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
 
     Args:
-        modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
         n (int): number of gradients to accumulate.
         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
         stop (bool, optional):
-            this module prevents next modules from stepping unless
-
-    Examples:
-        Adam with gradients accumulated for 16 batches.
+            this module prevents next modules from stepping unless ``n`` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
 
-
+    ## Examples:
 
-
-            model.parameters(),
-            tz.m.GradientAccumulation(
-                modules=[tz.m.Adam(), tz.m.LR(1e-2)],
-                n=16
-            )
-        )
+    Adam with gradients accumulated for 16 batches.
 
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.GradientAccumulation(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2),
+    )
+    ```
     """
-    def __init__(self,
+    def __init__(self, n: int, mean=True, stop=True):
         defaults = dict(n=n, mean=mean, stop=stop)
         super().__init__(defaults)
-        self.set_child('modules', modules)
 
 
     @torch.no_grad
     def step(self, var):
         accumulator = self.get_state(var.params, 'accumulator')
-        settings = self.
+        settings = self.defaults
         n = settings['n']; mean = settings['mean']; stop = settings['stop']
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
@@ -54,15 +120,15 @@ class GradientAccumulation(Module):
             if mean:
                 torch._foreach_div_(accumulator, n)
 
-            var.update =
-            var = self.children['modules'].step(var)
+            var.update = accumulator
 
             # zero accumulator
-
+            self.clear_state_keys('accumulator')
 
         else:
             # prevent update
             if stop:
+                var.update = None
                 var.stop=True
                 var.skip_update=True
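The rewritten module no longer wraps child modules; it accumulates whatever update flows into it and releases the (mean) accumulated value every ``n``-th step, blocking the rest of the chain in between. The same accumulate-then-step pattern in an ordinary PyTorch loop, shown here as an assumed standalone sketch rather than torchzero code:

```python
# Accumulate-then-step pattern that GradientAccumulation implements inside a torchzero chain,
# written as a plain PyTorch loop (assumed model/loss_fn/data_loader; not torchzero's code).
import torch

def train(model, loss_fn, data_loader, n: int = 16, lr: float = 1e-2):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    opt.zero_grad()
    for step, (x, y) in enumerate(data_loader, start=1):
        loss = loss_fn(model(x), y) / n   # divide so the accumulated gradient is a mean over n micro-batches
        loss.backward()                   # .grad accumulates across calls until zero_grad()
        if step % n == 0:
            opt.step()                    # one parameter update per n accumulated gradients
            opt.zero_grad()
```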
torchzero/modules/misc/homotopy.py
ADDED

@@ -0,0 +1,59 @@
+from collections.abc import Callable
+from abc import ABC, abstractmethod
+import torch
+from ...core import Module
+from ...core import Chainable
+
+class HomotopyBase(Module):
+    def __init__(self, defaults: dict | None = None):
+        super().__init__(defaults)
+
+    @abstractmethod
+    def loss_transform(self, loss: torch.Tensor) -> torch.Tensor:
+        """transform the loss"""
+
+    @torch.no_grad
+    def step(self, var):
+        if var.loss is not None:
+            var.loss = self.loss_transform(var.loss)
+
+        closure = var.closure
+        if closure is None: raise RuntimeError("SquareHomotopy requires closure")
+
+        def homotopy_closure(backward=True):
+            if backward:
+                with torch.enable_grad():
+                    loss = self.loss_transform(closure(False))
+                    grad = torch.autograd.grad(loss, var.params, allow_unused=True)
+                    for p,g in zip(var.params, grad):
+                        p.grad = g
+            else:
+                loss = self.loss_transform(closure(False))
+
+            return loss
+
+        var.closure = homotopy_closure
+        return var
+
+class SquareHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return loss.square().copysign(loss)
+
+class SqrtHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return (loss+1e-12).sqrt()
+
+class ExpHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return loss.exp()
+
+class LogHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return (loss+1e-12).log()
+
+class LambdaHomotopy(HomotopyBase):
+    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
+        defaults = dict(fn=fn)
+        super().__init__(defaults)
+
+    def loss_transform(self, loss): return self.defaults['fn'](loss)
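These modules wrap the closure so every loss evaluation, and therefore every gradient, goes through ``loss_transform``; by the chain rule the gradient of ``g(f)`` is ``g'(f) * grad f``, so for example ``SqrtHomotopy`` rescales gradients by ``1/(2*sqrt(f))``. A quick check with plain autograd, independent of torchzero:

```python
# Chain-rule effect of a loss transform on gradients (plain autograd check).
import torch

x = torch.tensor([3.0], requires_grad=True)

f = (x ** 2).sum()                                                   # f = 9, df/dx = 2x = 6
g_sqrt = torch.autograd.grad((f + 1e-12).sqrt(), x)[0]               # 6 / (2*sqrt(9)) = 1
g_log = torch.autograd.grad((x ** 2).sum().add(1e-12).log(), x)[0]   # 6 / 9 ≈ 0.667

print(g_sqrt, g_log)  # tensor([1.]) tensor([0.6667])
```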
torchzero/modules/misc/misc.py
CHANGED
@@ -1,12 +1,22 @@
 from collections import deque
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
+from functools import partial
 from operator import itemgetter
 from typing import Literal
 
 import torch
 
 from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
-from ...utils import
+from ...utils import (
+    Distributions,
+    Metrics,
+    NumberList,
+    TensorList,
+    set_storage_,
+    tofloat,
+    unpack_dicts,
+    unpack_states,
+)
 
 
 class Previous(TensorwiseTransform):
@@ -139,7 +149,7 @@ class UpdateSign(Transform):
 
 class GraftToGrad(Transform):
     """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
-    def __init__(self, tensorwise:bool=False, ord:
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, uses_grad=True, target=target)
 
@@ -151,7 +161,7 @@ class GraftToGrad(Transform):
 
 class GraftGradToUpdate(Transform):
     """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
-    def __init__(self, tensorwise:bool=False, ord:
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, uses_grad=True, target=target)
 
@@ -164,7 +174,7 @@ class GraftGradToUpdate(Transform):
 
 class GraftToParams(Transform):
     """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
-    def __init__(self, tensorwise:bool=False, ord:
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4, target: Target = 'update'):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, uses_grad=False, target=target)
 
@@ -194,7 +204,7 @@ class FillLoss(Module):
     @torch.no_grad
     def step(self, var):
         alpha = self.get_settings(var.params, 'alpha')
-        loss = var.get_loss(backward=self.
+        loss = var.get_loss(backward=self.defaults['backward'])
         var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
         return var
 
@@ -207,7 +217,7 @@ class MulByLoss(Module):
     @torch.no_grad
     def step(self, var):
         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.
+        loss = var.get_loss(backward=self.defaults['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
         torch._foreach_mul_(var.update, mul)
         return var
@@ -221,7 +231,7 @@ class DivByLoss(Module):
     @torch.no_grad
     def step(self, var):
         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.
+        loss = var.get_loss(backward=self.defaults['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
         torch._foreach_div_(var.update, mul)
         return var
@@ -229,15 +239,14 @@ class DivByLoss(Module):
 
 class NoiseSign(Transform):
     """Outputs random tensors with sign copied from the update."""
-    def __init__(self, distribution:Distributions = 'normal',
-        defaults = dict(distribution=distribution,
+    def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
+        defaults = dict(distribution=distribution, variance=variance)
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-
-
-        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+        variance = unpack_dicts(settings, 'variance')
+        return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
 
 class HpuEstimate(Transform):
     """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
@@ -257,7 +266,7 @@ class HpuEstimate(Transform):
         for p, c in zip(prev_params, params): p.copy_(c)
         for p, c in zip(prev_update, tensors): p.copy_(c)
         torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
-        self.store(params,
+        self.store(params, 'y', y)
 
     @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
@@ -295,7 +304,7 @@ class RandomHvp(Module):
 
         rgrad = None
         for i in range(n_samples):
-            u = params.sample_like(distribution=distribution)
+            u = params.sample_like(distribution=distribution, variance=1)
 
             Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                   h=h, normalize=True, retain_grad=i < n_samples-1)
@@ -314,3 +323,61 @@ class RandomHvp(Module):
 
         var.update = list(D)
         return var
+
+@torch.no_grad
+def _load_best_parameters(params: Sequence[torch.Tensor], best_params: Sequence[torch.Tensor]):
+    for p_cur, p_best in zip(params, best_params):
+        set_storage_(p_cur, p_best)
+
+class SaveBest(Module):
+    """Saves best parameters found so far, ones that have lowest loss. Put this as the last module.
+
+    Adds the following attrs:
+
+    - ``best_params`` - a list of tensors with best parameters.
+    - ``best_loss`` - loss value with ``best_params``.
+    - ``load_best_parameters`` - a function that sets parameters to the best parameters./
+
+    ## Examples
+    ```python
+    def rosenbrock(x, y):
+        return (1 - x)**2 + (100 * (y - x**2))**2
+
+    xy = torch.tensor((-1.1, 2.5), requires_grad=True)
+    opt = tz.Modular(
+        [xy],
+        tz.m.NAG(0.999),
+        tz.m.LR(1e-6),
+        tz.m.SaveBest()
+    )
+
+    # optimize for 1000 steps
+    for i in range(1000):
+        loss = rosenbrock(*xy)
+        opt.zero_grad()
+        loss.backward()
+        opt.step(loss=loss) # SaveBest needs closure or loss
+
+    # NAG overshot, but we saved the best params
+    print(f'{rosenbrock(*xy) = }') # >> 3.6583
+    print(f"{opt.attrs['best_loss'] = }") # >> 0.000627
+
+    # load best parameters
+    opt.attrs['load_best_params']()
+    print(f'{rosenbrock(*xy) = }') # >> 0.000627
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        loss = tofloat(var.get_loss(False))
+        lowest_loss = self.global_state.get('lowest_loss', float("inf"))
+
+        if loss < lowest_loss:
+            self.global_state['lowest_loss'] = loss
+            best_params = var.attrs['best_params'] = [p.clone() for p in var.params]
+            var.attrs['best_loss'] = loss
+            var.attrs['load_best_params'] = partial(_load_best_parameters, params=var.params, best_params=best_params)
+
+        return var
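The graft transforms above keep one direction but give it another vector's length. In the global (non-tensorwise) case with the Euclidean norm the rescaling reduces to a single ratio; a hypothetical standalone helper (illustration, not the Transform API) looks like:

```python
# Grafting rescale in its simplest form: keep `update`'s direction, take `grad`'s norm
# (hypothetical helper for illustration, not torchzero's Transform implementation).
import torch

def graft(update: torch.Tensor, grad: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return update * (grad.norm() / (update.norm() + eps))
```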
torchzero/modules/misc/multistep.py
CHANGED

@@ -97,7 +97,7 @@ class NegateOnLossIncrease(Module):
     def step(self, var):
         closure = var.closure
         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-        backtrack = self.
+        backtrack = self.defaults['backtrack']
 
         update = var.get_update()
         f_0 = var.get_loss(backward=False)
@@ -123,36 +123,72 @@ class NegateOnLossIncrease(Module):
 
 
 class Online(Module):
-    """Allows certain modules to be used for mini-batch optimization.
-
+    """Allows certain modules to be used for mini-batch optimization.
+
+    Examples:
+
+    Online L-BFGS with Backtracking line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Online(tz.m.LBFGS()),
+        tz.m.Backtracking()
+    )
+    ```
+
+    Online L-BFGS trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
+    )
+    ```
+
+    """
+    def __init__(self, *modules: Module,):
         super().__init__()
 
-        self.set_child('module',
+        self.set_child('module', modules)
 
     @torch.no_grad
-    def
+    def update(self, var):
         closure = var.closure
         if closure is None: raise ValueError("Closure must be passed for Online")
+
         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step
+
         params = TensorList(var.params)
         p_cur = params.clone()
         p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+
         module = self.children['module']
+        var_c = var.clone(clone_update=False)
 
+        # on 1st step just step and store previous params
        if step == 1:
-            var = module.step(var.clone(clone_update=False))
-
             p_prev.copy_(params)
-            return var
 
-
+            module.update(var_c)
+            var.update_attrs_from_clone_(var_c)
+            return
+
+        # restore previous params and update
         var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
         params.set_(p_prev)
         module.reset_for_online()
         module.update(var_prev)
 
-        # restore current params
+        # restore current params and update
         params.set_(p_cur)
         p_prev.copy_(params)
-
+        module.update(var_c)
+        var.update_attrs_from_clone_(var_c)
+
+    @torch.no_grad
+    def apply(self, var):
+        module = self.children['module']
+        return module.apply(var.clone(clone_update=False))
+
+    def get_H(self, var):
+        return self.children['module'].get_H(var)
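``NegateOnLossIncrease`` acts as a safeguard around a proposed step; the sketch below is a plain-tensor reading of what the name and the visible ``f_0``/``backtrack`` logic suggest (an assumption for illustration, not the module's exact rule):

```python
# Rough standalone sketch of a negate-on-loss-increase safeguard
# (assumed semantics for illustration; the module's exact rule lives in multistep.py).
import torch

@torch.no_grad()
def negate_on_increase(params, update, closure, f_0, backtrack=True):
    for p, u in zip(params, update):
        p.sub_(u)                       # try the proposed step x - u
    f_1 = closure()
    if f_1 > f_0:                       # the step made the loss worse
        if backtrack:
            for p, u in zip(params, update):
                p.add_(u)               # undo it
        else:
            for p, u in zip(params, update):
                p.add_(2 * u)           # flip it: end up at x + u
```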
torchzero/modules/misc/regularization.py
CHANGED

@@ -1,12 +1,8 @@
-from collections import deque
-from collections.abc import Iterable
-from operator import itemgetter
-from typing import Literal
-
 import torch
 
-from ...core import Chainable, Module, Target,
-from ...
+from ...core import Chainable, Module, Target, Transform
+from ...core.reformulation import Reformulation
+from ...utils import Distributions, NumberList, TensorList
 
 
 class Dropout(Transform):
@@ -121,8 +117,8 @@ class PerturbWeights(Module):
     Args:
         alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
         relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
-
-
+        distribution (bool, optional):
+            distribution of the random perturbation. Defaults to False.
     """
     def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
         defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
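The PerturbWeights docstring above describes the rule: noise scaled by ``alpha``, optionally multiplied by each parameter's mean absolute value. A standalone sketch of that rule for the normal distribution (illustration only, not the module's code):

```python
# Perturbation rule described in PerturbWeights' docstring, as a standalone helper
# (normal distribution only; illustration, not the module's implementation).
import torch

@torch.no_grad()
def perturb_weights(params, alpha: float = 0.1, relative: bool = True):
    for p in params:
        scale = alpha * p.abs().mean() if relative else alpha
        p.add_(torch.randn_like(p) * scale)
```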
torchzero/modules/misc/split.py
CHANGED
@@ -1,4 +1,5 @@
-
+import warnings
+from collections.abc import Callable, Sequence, Iterable
 from typing import cast
 
 import torch
@@ -22,59 +23,78 @@ def _split(
     if var.update is not None:
         split_update = [u for i,u in enumerate(var.update) if i in idxs]
 
-    split_var = var.clone(clone_update=False)
+    split_var = var.clone(clone_update=False, parent=var)
     split_var.params = split_params
     split_var.grad = split_grad
     split_var.update = split_update
 
     split_var = module.step(split_var)
 
-
-
+    # those should be set due to var being parent
+    if split_var.grad is not None:
+        assert var.grad is not None
+
+    if split_var.loss is not None:
+        assert var.loss is not None
 
     if split_var.update is not None:
 
+        # make sure update is set, it will be filled with ``true`` and ``false`` tensors
         if var.update is None:
             if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
             else: var.update = [g.clone() for g in var.grad]
 
+        # set all tensors from this split
         for idx, u in zip(idxs, split_var.update):
             var.update[idx] = u
 
-    var.update_attrs_from_clone_(split_var)
     return var
 
-
-
+_SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.Tensor] | torch.nn.Module | Iterable[torch.nn.Module]
+Filter = _SingleFilter | Iterable[_SingleFilter]
 
-
-
-
-
-
-
-    standard Muon with Adam fallback
-
-    .. code-block:: python
-
-        opt = tz.Modular(
-            model.head.parameters(),
-            tz.m.Split(
-                # apply muon only to 2D+ parameters
-                filter = lambda t: t.ndim >= 2,
-                true = [
-                    tz.m.HeavyBall(),
-                    tz.m.Orthogonalize(),
-                    tz.m.LR(1e-2),
-                ],
-                false = tz.m.Adam()
-            ),
-            tz.m.LR(1e-2)
-        )
+def _make_filter(filter: Filter):
+    if callable(filter): return filter
+    if isinstance(filter, torch.Tensor):
+        return lambda x: x is filter
+    if isinstance(filter, torch.nn.Module):
+        return _make_filter(filter.parameters())
 
+    # iterable
+    filters = [_make_filter(f) for f in filter]
+    return lambda x: any(f(x) for f in filters)
 
+class Split(Module):
+    """Apply ``true`` modules to all parameters filtered by ``filter``, apply ``false`` modules to all other parameters.
+
+    Args:
+        filter (Filter, bool]):
+            a filter that selects tensors to be optimized by ``true``.
+            - tensor or iterable of tensors (e.g. ``encoder.parameters()``).
+            - function that takes in tensor and outputs a bool (e.g. ``lambda x: x.ndim >= 2``).
+            - a sequence of above (acts as "or", so returns true if any of them is true).
+
+        true (Chainable | None): modules that are applied to tensors where ``filter`` is ``True``.
+        false (Chainable | None): modules that are applied to tensors where ``filter`` is ``False``.
+
+    ### Examples:
+
+    Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NAG(0.95),
+        tz.m.Split(
+            lambda p: p.ndim >= 2,
+            true = tz.m.Orthogonalize(),
+            false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
+        ),
+        tz.m.LR(1e-2),
+    )
+    ```
     """
-    def __init__(self, filter:
+    def __init__(self, filter: Filter, true: Chainable | None, false: Chainable | None):
         defaults = dict(filter=filter)
         super().__init__(defaults)
 
@@ -84,7 +104,7 @@ class Split(Module):
     def step(self, var):
 
         params = var.params
-        filter = self.settings[params[0]]['filter']
+        filter = _make_filter(self.settings[params[0]]['filter'])
 
         true_idxs = []
         false_idxs = []
@@ -92,11 +112,11 @@ class Split(Module):
             if filter(p): true_idxs.append(i)
            else: false_idxs.append(i)
 
-        if 'true' in self.children:
+        if 'true' in self.children and len(true_idxs) > 0:
            true = self.children['true']
            var = _split(true, idxs=true_idxs, params=params, var=var)
 
-        if 'false' in self.children:
+        if 'false' in self.children and len(false_idxs) > 0:
            false = self.children['false']
            var = _split(false, idxs=false_idxs, params=params, var=var)
 
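The new ``filter`` argument accepts a predicate, a tensor collection, an ``nn.Module``, or a list mixing them (treated as "or"). A small self-contained usage sketch along the lines of the docstring example (the toy model and the exact module combination are illustrative, not prescribed):

```python
# Using Split's widened filter argument: combine an nn.Module's parameters with a predicate.
# Toy model and module choice are illustrative; the tz.m.* modules appear in the examples above.
import torch
import torchzero as tz

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

opt = tz.Modular(
    model.parameters(),
    tz.m.Split(
        # "or" of two filter forms: every parameter of the first Linear, plus any 2D+ tensor
        filter=[model[0], lambda p: p.ndim >= 2],
        true=tz.m.Orthogonalize(),
        false=tz.m.Adam(),
    ),
    tz.m.LR(1e-2),
)
```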