torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/ops/debug.py
@@ -0,0 +1,25 @@
+from collections import deque
+
+import torch
+
+from ...core import Module
+from ...utils.tensorlist import Distributions
+
+class PrintUpdate(Module):
+    def __init__(self, text = 'update = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, vars):
+        self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{vars.update}')
+        return vars
+
+class PrintShape(Module):
+    def __init__(self, text = 'shapes = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, vars):
+        shapes = [u.shape for u in vars.update] if vars.update is not None else None
+        self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{shapes}')
+        return vars
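PrintUpdate and PrintShape are ordinary chain members, so the quickest way to inspect what a pipeline produces is to splice them between other modules. Below is a minimal usage sketch, assuming the top-level tz.Modular optimizer and the tz.m module namespace described in the project README (these names are assumptions and may differ in 0.3.2), and a closure following the backward=True convention used throughout the new modules:

    import torch
    import torchzero as tz  # assumption: the package exposes Modular and the tz.m module namespace

    model = torch.nn.Linear(10, 1)
    xb, yb = torch.randn(16, 10), torch.randn(16, 1)

    # PrintShape before Adam logs the incoming gradient shapes,
    # PrintUpdate after Adam logs the transformed update tensors.
    opt = tz.Modular(
        model.parameters(),
        [tz.m.PrintShape(), tz.m.Adam(), tz.m.PrintUpdate(), tz.m.LR(1e-2)],
    )

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(xb), yb)
        if backward:
            opt.zero_grad()  # assumption: Modular behaves like a torch.optim.Optimizer
            loss.backward()
        return loss

    opt.step(closure)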
torchzero/modules/ops/misc.py
@@ -0,0 +1,419 @@
+from collections import deque
+from collections.abc import Iterable
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, TensorwiseTransform, Target, Transform, Vars
+from ...utils import Distributions, NumberList, TensorList
+
+
+class Previous(TensorwiseTransform):
+    """Maintains an update from n steps back, for example if n=1, returns previous update"""
+    def __init__(self, n=1, target: Target = 'update'):
+        defaults = dict(n=n)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        n = self.settings[param]['n']
+        state = self.state[param]
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=n+1)
+
+        state['history'].append(tensor)
+
+        return state['history'][0]
+
+
+class LastDifference(Transform):
+    """Difference between past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        prev_target = self.get_state('prev_target', params=params) # initialized to 0
+        difference = torch._foreach_sub(tensors, prev_target)
+        for p, c in zip(prev_target, tensors): p.set_(c)
+        return difference
+
+class LastGradDifference(Module):
+    """Difference between past two grads."""
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, vars):
+        grad = vars.get_grad()
+        prev_grad = self.get_state('prev_grad', params=vars.params) # initialized to 0
+        difference = torch._foreach_sub(grad, prev_grad)
+        for p, c in zip(prev_grad, grad): p.set_(c)
+        vars.update = list(difference)
+        return vars
+
+
+class LastProduct(Transform):
+    """Product of past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        prev_target = self.get_state('prev_target', params=params, init=torch.ones_like) # initialized to 1 for prod
+        prod = torch._foreach_mul(tensors, prev_target)
+        for p, c in zip(prev_target, tensors): p.set_(c)
+        return prod
+
+class LastRatio(Transform):
+    """Ratio between past two updates."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+        defaults = dict(numerator=numerator)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        prev_target = self.get_state('prev_target', params=params, init = torch.ones_like) # initialized to ones
+        numerator = self.settings[params[0]]['numerator']
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
+        else: ratio = torch._foreach_div(prev_target, tensors)
+        for p, c in zip(prev_target, tensors): p.set_(c)
+        return ratio
+
+class LastAbsoluteRatio(Transform):
+    """Ratio between absolute values of past two updates."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+        defaults = dict(numerator=numerator, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        prev_target = self.get_state('prev_target', params=params, init = torch.ones_like) # initialized to ones
+        numerator = self.settings[params[0]]['numerator']
+        eps = self.get_settings('eps', params=params, cls = NumberList)
+
+        torch._foreach_abs_(tensors)
+        torch._foreach_clamp_min_(prev_target, eps)
+
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
+        else: ratio = torch._foreach_div(prev_target, tensors)
+        for p, c in zip(prev_target, tensors): p.set_(c)
+        return ratio
+
+class GradSign(Transform):
+    """copy gradient sign to update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        return [t.copysign_(g) for t,g in zip(tensors, grads)]
+
+class UpdateSign(Transform):
+    """use per-weight magnitudes from grad while using sign from update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
+
+class GraftToGrad(Transform):
+    """use gradient norm and update direction."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
+        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class GraftGradToUpdate(Transform):
+    """use update norm and gradient direction."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
+        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+class GraftToParams(Transform):
+    """makes update norm be set to parameter norm, but norm won't go below eps"""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
+        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class Relative(Transform):
+    """multiplies update by absolute parameter values to make it relative to their magnitude, min_value is minimum value to avoid getting stuck at 0"""
+    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+        defaults = dict(min_value=min_value)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        mul = TensorList(params).abs().clamp_(self.get_settings('min_value', params=params))
+        torch._foreach_mul_(tensors, mul)
+        return tensors
+
+class FillLoss(Module):
+    """makes tensors filled with loss value times alpha"""
+    def __init__(self, alpha: float = 1, backward: bool = True):
+        defaults = dict(alpha=alpha, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, vars):
+        alpha = self.get_settings('alpha', params=vars.params)
+        loss = vars.get_loss(backward=self.settings[vars.params[0]]['backward'])
+        vars.update = [torch.full_like(p, loss*a) for p,a in zip(vars.params, alpha)]
+        return vars
+
+class MulByLoss(Transform):
+    """multiplies update by loss times alpha"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True, target: Target = 'update'):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars): #vars used for loss
+        alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
+        loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_mul_(tensors, mul)
+        return tensors
+
+class DivByLoss(Transform):
+    """divides update by loss times alpha"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True, target: Target = 'update'):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars): #vars used for loss
+        alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
+        loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_div_(tensors, mul)
+        return tensors
+
+
+
+def _sequential_step(self: Module, vars: Vars, sequential: bool):
+    params = vars.params
+    steps = self.settings[params[0]]['steps']
+
+    if sequential: modules = self.get_children_sequence()
+    else: modules = [self.children['module']] * steps
+
+    if vars.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+    # store original params unless this is last module and can update params directly
+    params_before_steps = None if (vars.is_last and vars.last_module_lrs is None) else [p.clone() for p in params]
+
+    # first step - pass vars as usual
+    vars = modules[0].step(vars)
+    new_vars = vars
+
+    # subsequent steps - update parameters and create new vars
+    if len(modules) > 1:
+        for m in modules[1:]:
+
+            # update params
+            if (not new_vars.skip_update):
+                if new_vars.last_module_lrs is not None:
+                    torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
+
+                torch._foreach_sub_(params, new_vars.get_update())
+
+            # create new vars since we are at a new point, that means grad, update and loss will be None
+            new_vars = Vars(params=new_vars.params, closure=new_vars.closure,
+                            model=new_vars.model, current_step=new_vars.current_step + 1)
+
+            # step
+            new_vars = m.step(new_vars)
+
+    # final parameter update
+    if (not new_vars.skip_update):
+        if new_vars.last_module_lrs is not None:
+            torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
+
+        torch._foreach_sub_(params, new_vars.get_update())
+
+    # if last module, update is applied so return new vars
+    if params_before_steps is None:
+        new_vars.stop = True
+        new_vars.skip_update = True
+        return new_vars
+
+    # otherwise use parameter difference as update
+    vars.update = list(torch._foreach_sub(params_before_steps, params))
+    for p, bef in zip(params, params_before_steps):
+        p.set_(bef) # pyright:ignore[reportArgumentType]
+    return vars
+
+class Multistep(Module):
+    def __init__(self, module: Chainable, steps: int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_child('module', module)
+
+    @torch.no_grad
+    def step(self, vars):
+        return _sequential_step(self, vars, sequential=False)
+
+class Sequential(Module):
+    def __init__(self, modules: Iterable[Chainable], steps: int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_children_sequence(modules)
+
+    @torch.no_grad
+    def step(self, vars):
+        return _sequential_step(self, vars, sequential=True)
+
+
+class GradientAccumulation(Module):
+    """gradient accumulation"""
+    def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+        defaults = dict(n=n, mean=mean, stop=stop)
+        super().__init__(defaults)
+        self.set_child('modules', modules)
+
+
+    @torch.no_grad
+    def step(self, vars):
+        accumulator = self.get_state('accumulator', params=vars.params)
+        settings = self.settings[vars.params[0]]
+        n = settings['n']; mean = settings['mean']; stop = settings['stop']
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        # add update to accumulator
+        torch._foreach_add_(accumulator, vars.get_update())
+
+        # step with accumulated updates
+        if step % n == 0:
+            if mean:
+                torch._foreach_div_(accumulator, n)
+
+            vars.update = [a.clone() for a in accumulator]
+            vars = self.children['modules'].step(vars)
+
+            # zero accumulator
+            torch._foreach_zero_(accumulator)
+
+        else:
+            # prevent update
+            if stop:
+                vars.stop=True
+                vars.skip_update=True
+
+        return vars
+
+
+class Dropout(Transform):
+    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
+        defaults = dict(p=p, graft=graft)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        tensors = TensorList(tensors)
+        p = self.get_settings('p', params=params, cls=NumberList)
+        graft = self.settings[params[0]]['graft']
+
+        if graft:
+            target_norm = tensors.global_vector_norm()
+            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
+class WeightDropout(Module):
+    """Applies dropout directly to weights."""
+    def __init__(self, p: float = 0.5, graft: bool = True):
+        defaults = dict(p=p, graft=graft)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, vars):
+        closure = vars.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(vars.params)
+        p = self.get_settings('p', params=params)
+        mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
+
+        @torch.no_grad
+        def dropout_closure(backward=True):
+            orig_params = params.clone()
+            params.mul_(mask)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.copy_(orig_params)
+            return loss
+
+        vars.closure = dropout_closure
+        return vars
+
+class NoiseSign(Transform):
+    """uses random vector with update sign"""
+    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
+        defaults = dict(distribution=distribution, alpha=alpha)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        alpha = self.get_settings('alpha', params=params)
+        distribution = self.settings[params[0]]['distribution']
+        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+
+
+class NegateOnLossIncrease(Module):
+    def __init__(self, backtrack=True):
+        defaults = dict(backtrack=backtrack)
+        super().__init__(defaults=defaults)
+
+    @torch.no_grad
+    def step(self, vars):
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+        backtrack = self.settings[vars.params[0]]['backtrack']
+
+        update = vars.get_update()
+        f_0 = vars.get_loss(backward=False)
+
+        torch._foreach_sub_(vars.params, update)
+        f_1 = closure(False)
+
+        if f_1 <= f_0:
+            if vars.is_last and vars.last_module_lrs is None:
+                vars.stop = True
+                vars.skip_update = True
+                return vars

+            torch._foreach_add_(vars.params, update)
+            return vars
+
+        torch._foreach_add_(vars.params, update)
+        if backtrack:
+            torch._foreach_neg_(vars.update)
+        else:
+            torch._foreach_zero_(vars.update)
+        return vars
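GradientAccumulation above sums the incoming update into a per-parameter accumulator and only lets its child modules run on every n-th call (dividing by n first when mean=True); on the other calls it sets stop/skip_update so nothing is applied. A hedged usage sketch, again assuming the tz.Modular wrapper and tz.m namespace from the README (hypothetical if the 0.3.2 public names differ):

    import torch
    import torchzero as tz  # assumption: Modular wrapper and tz.m namespace as in the README

    model = torch.nn.Linear(10, 1)
    batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

    # Accumulate 4 micro-batch updates, then let Adam -> LR consume the averaged update.
    opt = tz.Modular(
        model.parameters(),
        [tz.m.GradientAccumulation([tz.m.Adam(), tz.m.LR(1e-2)], n=4, mean=True)],
    )

    for xb, yb in batches:
        def closure(backward=True, xb=xb, yb=yb):
            loss = torch.nn.functional.mse_loss(model(xb), yb)
            if backward:
                opt.zero_grad()
                loss.backward()
            return loss
        opt.step(closure)  # parameters only move on every 4th call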
torchzero/modules/ops/multi.py
@@ -0,0 +1,137 @@
+#pyright: reportIncompatibleMethodOverride=false
+""""""
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
+from operator import itemgetter
+from typing import Any
+
+import torch
+
+from ...core import Chainable, Module, Target, Vars, maybe_chain
+from ...utils import TensorList, tensorlist
+
+
+class MultiOperation(Module, ABC):
+    """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
+    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
+        super().__init__(defaults=defaults)
+
+        self.operands = {}
+        for k,v in operands.items():
+
+            if isinstance(v, (Module, Sequence)):
+                self.set_child(k, v)
+                self.operands[k] = self.children[k]
+            else:
+                self.operands[k] = v
+
+        if not self.children:
+            raise ValueError('At least one operand must be a module')
+
+    @abstractmethod
+    def transform(self, vars: Vars, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+        """applies the operation to operands"""
+        raise NotImplementedError
+
+    @torch.no_grad
+    def step(self, vars: Vars) -> Vars:
+        # pass cloned update to all module operands
+        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
+
+        for k,v in self.operands.items():
+            if k in self.children:
+                v: Module
+                updated_vars = v.step(vars.clone(clone_update=True))
+                processed_operands[k] = updated_vars.get_update()
+                vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+
+        transformed = self.transform(vars, **processed_operands)
+        vars.update = transformed
+        return vars
+
+
+
+class SubModules(MultiOperation):
+    def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
+        defaults = dict(alpha=alpha)
+        super().__init__(defaults, input=input, other=other)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        alpha = self.settings[vars.params[0]]['alpha']
+
+        if isinstance(input, (int,float)):
+            assert isinstance(other, list)
+            return input - TensorList(other).mul_(alpha)
+
+        if isinstance(other, (int, float)): torch._foreach_sub_(input, other * alpha)
+        else: torch._foreach_sub_(input, other, alpha=alpha)
+        return input
+
+class DivModules(MultiOperation):
+    def __init__(self, input: Chainable | float, other: Chainable | float):
+        defaults = {}
+        super().__init__(defaults, input=input, other=other)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        if isinstance(input, (int,float)):
+            assert isinstance(other, list)
+            return input / TensorList(other)
+
+        torch._foreach_div_(input, other)
+        return input
+
+class PowModules(MultiOperation):
+    def __init__(self, input: Chainable | float, exponent: Chainable | float):
+        defaults = {}
+        super().__init__(defaults, input=input, exponent=exponent)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        if isinstance(input, (int,float)):
+            assert isinstance(exponent, list)
+            return input ** TensorList(exponent)
+
+        torch._foreach_pow_(input, exponent)
+        return input
+
+class LerpModules(MultiOperation):
+    def __init__(self, input: Chainable, end: Chainable, weight: float):
+        defaults = dict(weight=weight)
+        super().__init__(defaults, input=input, end=end)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
+        torch._foreach_lerp_(input, end, weight=self.settings[vars.params[0]]['weight'])
+        return input
+
+class ClipModules(MultiOperation):
+    def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
+        defaults = {}
+        super().__init__(defaults, input=input, min=min, max=max)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        return TensorList(input).clamp_(min=min, max=max)
+
+
+class GraftModules(MultiOperation):
+    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6, strength:float=1):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
+        super().__init__(defaults, direction=direction, magnitude=magnitude)
+
+    @torch.no_grad
+    def transform(self, vars, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
+        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[vars.params[0]])
+        return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
+
+
+class Where(MultiOperation):
+    def __init__(self, condition: Chainable, input: Chainable | float, other: Chainable | float):
+        super().__init__({}, condition=condition, input=input, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
+        return tensorlist.where(TensorList(condition).as_bool(), input, other)
+
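Each MultiOperation subclass steps its module operands on a cloned Vars and then combines the resulting updates, so two optimizers can be merged into a single update rule. GraftModules, for example, keeps the direction from one operand and rescales it to the norm of the other, which is how "grafting" is usually described. A sketch under the same assumptions about tz.Modular and tz.m, and assuming Shampoo and Adam are exported under these names:

    import torch
    import torchzero as tz  # assumption: Modular wrapper and tz.m namespace; class names assumed

    model = torch.nn.Linear(10, 1)
    xb, yb = torch.randn(16, 10), torch.randn(16, 1)

    # Step direction from Shampoo, per-tensor magnitude grafted from Adam's update.
    opt = tz.Modular(
        model.parameters(),
        [tz.m.GraftModules(direction=tz.m.Shampoo(), magnitude=tz.m.Adam()), tz.m.LR(1e-2)],
    )

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(xb), yb)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)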