torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
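
Most of the churn above is a reorganization: the optimizer modules moved from `torchzero/modules/optimizers/` to the new `torchzero/modules/adaptive/` package, `lr/` and parts of `ops/` were folded into the new `misc/` package, and several new subpackages (conjugate_gradient, least_squares, restarts, termination, trust_region, variance_reduction, zeroth_order) were added. A minimal sketch of what the file moves imply for deep imports, assuming no re-export shims were left behind (the top-level `tz.m` namespace used in the docstrings below may be unaffected):

```python
# torchzero 0.3.10 (old layout)
# from torchzero.modules.optimizers.adam import Adam

# torchzero 0.3.13 (new layout, per the {optimizers -> adaptive} moves above)
from torchzero.modules.adaptive.adam import Adam
```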

torchzero/modules/misc/escape.py
@@ -0,0 +1,62 @@
+import math
+
+from typing import Literal
+import torch
+
+from ...core import Modular, Module, Var, Chainable
+from ...utils import NumberList, TensorList
+
+
+class EscapeAnnealing(Module):
+    """If parameters stop changing, this runs a backward annealing random search"""
+    def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
+        defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
+        super().__init__(defaults)
+
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError("Escape requries closure")
+
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        max_region = self.get_settings(params, 'max_region', cls=NumberList)
+        max_iter = settings['max_iter']
+        tol = settings['tol']
+        n_tol = settings['n_tol']
+
+        n_bad = self.global_state.get('n_bad', 0)
+
+        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+        diff = params-prev_params
+        prev_params.copy_(params)
+
+        if diff.abs().global_max() <= tol:
+            n_bad += 1
+
+        else:
+            n_bad = 0
+
+        self.global_state['n_bad'] = n_bad
+
+        # no progress
+        f_0 = var.get_loss(False)
+        if n_bad >= n_tol:
+            for i in range(1, max_iter+1):
+                alpha = max_region * (i / max_iter)
+                pert = params.sphere_like(radius=alpha)
+
+                params.add_(pert)
+                f_star = closure(False)
+
+                if math.isfinite(f_star) and f_star < f_0-1e-12:
+                    var.update = None
+                    var.stop = True
+                    var.skip_update = True
+                    return var
+
+                params.sub_(pert)
+
+            self.global_state['n_bad'] = 0
+        return var
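
The new `EscapeAnnealing` module above tracks how far the parameters moved since the previous step; once they have stalled for `n_tol` consecutive steps, it samples random perturbations of growing radius (up to `max_region`) and keeps the first one that lowers the loss. A minimal usage sketch, modeled on the closure-style examples elsewhere in this diff; the `tz.m.EscapeAnnealing` export and the exact `opt.step(closure)` signature are assumptions:

```python
import torch
import torchzero as tz

def rosenbrock(x, y):
    return (1 - x)**2 + 100 * (y - x**2)**2

xy = torch.tensor((-1.1, 2.5), requires_grad=True)

opt = tz.Modular(
    [xy],
    tz.m.Adam(),
    tz.m.LR(1e-2),
    # after 10 stalled steps, try random perturbations up to radius 1
    tz.m.EscapeAnnealing(max_region=1.0, max_iter=1000, tol=1e-6, n_tol=10),
)

def closure(backward=True):
    # EscapeAnnealing re-evaluates the loss itself, so a closure is required
    loss = rosenbrock(*xy)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(500):
    opt.step(closure)
```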

torchzero/modules/misc/gradient_accumulation.py
@@ -0,0 +1,136 @@
+import torch
+
+from ...core import Chainable, Module
+
+
+# class GradientAccumulation(Module):
+#     """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
+
+#     Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
+#     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+#     .. note::
+#         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+#     Args:
+#         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+#         n (int): number of gradients to accumulate.
+#         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+#         stop (bool, optional):
+#             this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
+
+#     Examples:
+#         Adam with gradients accumulated for 16 batches.
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.GradientAccumulation(
+#                     [tz.m.Adam(), tz.m.LR(1e-2)],
+#                     n=16
+#                 )
+#             )
+
+#     """
+#     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+#         defaults = dict(n=n, mean=mean, stop=stop)
+#         super().__init__(defaults)
+#         self.set_child('modules', modules)
+
+
+#     @torch.no_grad
+#     def step(self, var):
+#         accumulator = self.get_state(var.params, 'accumulator')
+#         settings = self.defaults
+#         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+#         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+#         # add update to accumulator
+#         torch._foreach_add_(accumulator, var.get_update())
+
+#         # step with accumulated updates
+#         if step % n == 0:
+#             if mean:
+#                 torch._foreach_div_(accumulator, n)
+
+#             var.update = [a.clone() for a in accumulator]
+#             var = self.children['modules'].step(var)
+
+#             # zero accumulator
+#             torch._foreach_zero_(accumulator)
+
+#         else:
+#             # prevent update
+#             if stop:
+#                 var.update = None
+#                 var.stop=True
+#                 var.skip_update=True
+
+#         return var
+
+
+
+
+class GradientAccumulation(Module):
+    """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
+
+    Accumulating gradients for ``n`` steps is equivalent to increasing batch size by ``n``. Increasing the batch size
+    is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+    Note:
+        Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+    Args:
+        n (int): number of gradients to accumulate.
+        mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+        stop (bool, optional):
+            this module prevents next modules from stepping unless ``n`` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
+
+    ## Examples:
+
+    Adam with gradients accumulated for 16 batches.
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.GradientAccumulation(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2),
+    )
+    ```
+    """
+    def __init__(self, n: int, mean=True, stop=True):
+        defaults = dict(n=n, mean=mean, stop=stop)
+        super().__init__(defaults)
+
+
+    @torch.no_grad
+    def step(self, var):
+        accumulator = self.get_state(var.params, 'accumulator')
+        settings = self.defaults
+        n = settings['n']; mean = settings['mean']; stop = settings['stop']
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        # add update to accumulator
+        torch._foreach_add_(accumulator, var.get_update())
+
+        # step with accumulated updates
+        if step % n == 0:
+            if mean:
+                torch._foreach_div_(accumulator, n)
+
+            var.update = accumulator
+
+            # zero accumulator
+            self.clear_state_keys('accumulator')
+
+        else:
+            # prevent update
+            if stop:
+                var.update = None
+                var.stop=True
+                var.skip_update=True
+
+        return var
+
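
The rewritten `GradientAccumulation` no longer wraps child modules; it simply sits first in the chain and, on all but every ``n``-th step, stops the chain so parameters stay put. A sketch of the resulting training loop, following the ``tz.Modular``/``tz.m`` pattern from the docstring above (the model and data here are placeholders):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation(n=16),  # accumulate 16 micro-batch gradients
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

for step in range(160):
    x, y = torch.randn(8, 10), torch.randn(8, 1)   # micro-batch
    loss = torch.nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    # on 15 of every 16 calls the chain is stopped (stop=True) and nothing is
    # applied; on the 16th, Adam steps with the mean accumulated gradient
    opt.step()
```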

torchzero/modules/misc/homotopy.py
@@ -0,0 +1,59 @@
+from collections.abc import Callable
+from abc import ABC, abstractmethod
+import torch
+from ...core import Module
+from ...core import Chainable
+
+class HomotopyBase(Module):
+    def __init__(self, defaults: dict | None = None):
+        super().__init__(defaults)
+
+    @abstractmethod
+    def loss_transform(self, loss: torch.Tensor) -> torch.Tensor:
+        """transform the loss"""
+
+    @torch.no_grad
+    def step(self, var):
+        if var.loss is not None:
+            var.loss = self.loss_transform(var.loss)
+
+        closure = var.closure
+        if closure is None: raise RuntimeError("SquareHomotopy requires closure")
+
+        def homotopy_closure(backward=True):
+            if backward:
+                with torch.enable_grad():
+                    loss = self.loss_transform(closure(False))
+                    grad = torch.autograd.grad(loss, var.params, allow_unused=True)
+                    for p,g in zip(var.params, grad):
+                        p.grad = g
+            else:
+                loss = self.loss_transform(closure(False))
+
+            return loss
+
+        var.closure = homotopy_closure
+        return var
+
+class SquareHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return loss.square().copysign(loss)
+
+class SqrtHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return (loss+1e-12).sqrt()
+
+class ExpHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return loss.exp()
+
+class LogHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return (loss+1e-12).log()
+
+class LambdaHomotopy(HomotopyBase):
+    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
+        defaults = dict(fn=fn)
+        super().__init__(defaults)
+
+    def loss_transform(self, loss): return self.defaults['fn'](loss)
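
Each homotopy module replaces ``var.loss`` and wraps ``var.closure`` so that every module after it optimizes a transformed objective, with gradients recomputed through the transform via ``torch.autograd.grad``. A minimal sketch using ``SqrtHomotopy`` to flatten a steep objective; the ``tz.m.SqrtHomotopy`` export and the closure-based ``opt.step(closure)`` call are assumptions:

```python
import torch
import torchzero as tz

x = torch.tensor([3.0, -2.0], requires_grad=True)

opt = tz.Modular(
    [x],
    tz.m.SqrtHomotopy(),  # downstream modules see sqrt(loss + 1e-12)
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

def closure(backward=True):
    loss = (x ** 4).sum()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(200):
    opt.step(closure)
```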

torchzero/modules/misc/misc.py
@@ -0,0 +1,383 @@
+from collections import deque
+from collections.abc import Iterable, Sequence
+from functools import partial
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...utils import (
+    Distributions,
+    Metrics,
+    NumberList,
+    TensorList,
+    set_storage_,
+    tofloat,
+    unpack_dicts,
+    unpack_states,
+)
+
+
+class Previous(TensorwiseTransform):
+    """Maintains an update from n steps back, for example if n=1, returns previous update"""
+    def __init__(self, n=1, target: Target = 'update'):
+        defaults = dict(n=n)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        n = setting['n']
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=n+1)
+
+        state['history'].append(tensor)
+
+        return state['history'][0]
+
+
+class LastDifference(Transform):
+    """Outputs difference between past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
+        difference = torch._foreach_sub(tensors, prev_tensors)
+        for p, c in zip(prev_tensors, tensors): p.set_(c)
+        return difference
+
+class LastGradDifference(Module):
+    """Outputs difference between past two gradients."""
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, var):
+        grad = var.get_grad()
+        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
+        difference = torch._foreach_sub(grad, prev_grad)
+        for p, c in zip(prev_grad, grad): p.copy_(c)
+        var.update = list(difference)
+        return var
+
+class LastParamDifference(Module):
+    """Outputs difference between past two parameters, which is the effective previous update."""
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        prev_params = self.get_state(var.params, 'prev_params') # initialized to 0
+        difference = torch._foreach_sub(params, prev_params)
+        for p, c in zip(prev_params, params): p.copy_(c)
+        var.update = list(difference)
+        return var
+
+
+
+class LastProduct(Transform):
+    """Outputs difference between past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+        prod = torch._foreach_mul(tensors, prev)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return prod
+
+class LastRatio(Transform):
+    """Outputs ratio between past two updates, the numerator is determined by :code:`numerator` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+        defaults = dict(numerator=numerator)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return ratio
+
+class LastAbsoluteRatio(Transform):
+    """Outputs ratio between absolute values of past two updates the numerator is determined by :code:`numerator` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+        defaults = dict(numerator=numerator, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        eps = NumberList(s['eps'] for s in settings)
+
+        torch._foreach_abs_(tensors)
+        torch._foreach_clamp_min_(prev, eps)
+
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return ratio
+
+class GradSign(Transform):
+    """Copies gradient sign to update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        return [t.copysign_(g) for t,g in zip(tensors, grads)]
+
+class UpdateSign(Transform):
+    """Outputs gradient with sign copied from the update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
+
+class GraftToGrad(Transform):
+    """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class GraftGradToUpdate(Transform):
+    """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+class GraftToParams(Transform):
+    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class Relative(Transform):
+    """Multiplies update by absolute parameter values to make it relative to their magnitude, :code:`min_value` is minimum allowed value to avoid getting stuck at 0."""
+    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+        defaults = dict(min_value=min_value)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
+        torch._foreach_mul_(tensors, mul)
+        return tensors
+
+class FillLoss(Module):
+    """Outputs tensors filled with loss value times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, backward: bool = True):
+        defaults = dict(alpha=alpha, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha = self.get_settings(var.params, 'alpha')
+        loss = var.get_loss(backward=self.defaults['backward'])
+        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+        return var
+
+class MulByLoss(Module):
+    """Multiplies update by loss times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.defaults['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_mul_(var.update, mul)
+        return var
+
+class DivByLoss(Module):
+    """Divides update by loss times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.defaults['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_div_(var.update, mul)
+        return var
+
+
+class NoiseSign(Transform):
+    """Outputs random tensors with sign copied from the update."""
+    def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
+        defaults = dict(distribution=distribution, variance=variance)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        variance = unpack_dicts(settings, 'variance')
+        return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
+
+class HpuEstimate(Transform):
+    """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
+    def __init__(self):
+        defaults = dict()
+        super().__init__(defaults, uses_grad=False)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_params', 'prev_update')
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
+        s = torch._foreach_sub(params, prev_params)
+        y = torch._foreach_sub(tensors, prev_update)
+        for p, c in zip(prev_params, params): p.copy_(c)
+        for p, c in zip(prev_update, tensors): p.copy_(c)
+        torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
+        self.store(params, 'y', y)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return [self.state[p]['y'] for p in params]
+
+class RandomHvp(Module):
+    """Returns a hessian-vector product with a random vector"""
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        distribution: Distributions = "normal",
+        update_freq: int = 1,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h=1e-3,
+    ):
+        defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        n_samples = settings['n_samples']
+        distribution = settings['distribution']
+        hvp_method = settings['hvp_method']
+        h = settings['h']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad = None
+            for i in range(n_samples):
+                u = params.sample_like(distribution=distribution, variance=1)
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=h, normalize=True, retain_grad=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+            if update_freq != 1:
+                assert D is not None
+                D_buf = self.get_state(params, "D", cls=TensorList)
+                D_buf.set_(D)
+
+        if D is None:
+            D = self.get_state(params, "D", cls=TensorList)
+
+        var.update = list(D)
+        return var
+
+@torch.no_grad
+def _load_best_parameters(params: Sequence[torch.Tensor], best_params: Sequence[torch.Tensor]):
+    for p_cur, p_best in zip(params, best_params):
+        set_storage_(p_cur, p_best)
+
+class SaveBest(Module):
+    """Saves best parameters found so far, ones that have lowest loss. Put this as the last module.
+
+    Adds the following attrs:
+
+    - ``best_params`` - a list of tensors with best parameters.
+    - ``best_loss`` - loss value with ``best_params``.
+    - ``load_best_parameters`` - a function that sets parameters to the best parameters./
+
+    ## Examples
+    ```python
+    def rosenbrock(x, y):
+        return (1 - x)**2 + (100 * (y - x**2))**2
+
+    xy = torch.tensor((-1.1, 2.5), requires_grad=True)
+    opt = tz.Modular(
+        [xy],
+        tz.m.NAG(0.999),
+        tz.m.LR(1e-6),
+        tz.m.SaveBest()
+    )
+
+    # optimize for 1000 steps
+    for i in range(1000):
+        loss = rosenbrock(*xy)
+        opt.zero_grad()
+        loss.backward()
+        opt.step(loss=loss) # SaveBest needs closure or loss
+
+    # NAG overshot, but we saved the best params
+    print(f'{rosenbrock(*xy) = }') # >> 3.6583
+    print(f"{opt.attrs['best_loss'] = }") # >> 0.000627
+
+    # load best parameters
+    opt.attrs['load_best_params']()
+    print(f'{rosenbrock(*xy) = }') # >> 0.000627
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        loss = tofloat(var.get_loss(False))
+        lowest_loss = self.global_state.get('lowest_loss', float("inf"))
+
+        if loss < lowest_loss:
+            self.global_state['lowest_loss'] = loss
+            best_params = var.attrs['best_params'] = [p.clone() for p in var.params]
+            var.attrs['best_loss'] = loss
+            var.attrs['load_best_params'] = partial(_load_best_parameters, params=var.params, best_params=best_params)
+
+        return var
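
Several of the new transforms above implement grafting (rescaling one direction to the norm of another). A hedged sketch of a common composition: take Adam's direction but keep the raw gradient's step length via ``GraftToGrad``, and track the best parameters with ``SaveBest`` as in its docstring; the ``tz.m.GraftToGrad`` export and its default global (non-tensorwise) norm are assumptions:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),          # direction
    tz.m.GraftToGrad(),   # rescale it to the gradient's norm
    tz.m.LR(1e-2),
    tz.m.SaveBest(),      # keeps best_params / best_loss in opt.attrs
)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(100):
    loss = torch.nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step(loss=loss)   # SaveBest needs the loss (or a closure)

print(opt.attrs['best_loss'])
```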