torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/misc.py
CHANGED
@@ -5,8 +5,8 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, …
-from ...utils import Distributions, NumberList, TensorList
+from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
 
 
 class Previous(TensorwiseTransform):
@@ -17,9 +17,8 @@ class Previous(TensorwiseTransform):
 
 
     @torch.no_grad
-    def …
-        n = …
-        state = self.state[param]
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        n = settings['n']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=n+1)
@@ -35,10 +34,10 @@ class LastDifference(Transform):
         super().__init__({}, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        …
-        difference = torch._foreach_sub(tensors, …
-        for p, c in zip(…
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev_target') # initialized to 0
+        difference = torch._foreach_sub(tensors, prev)
+        for p, c in zip(prev, tensors): p.set_(c)
         return difference
 
 class LastGradDifference(Module):
@@ -47,13 +46,13 @@ class LastGradDifference(Module):
         super().__init__({})
 
     @torch.no_grad
-    def step(self, …
-        grad = …
-        prev_grad = self.get_state('prev_grad' …
+    def step(self, var):
+        grad = var.get_grad()
+        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
         difference = torch._foreach_sub(grad, prev_grad)
         for p, c in zip(prev_grad, grad): p.set_(c)
-        …
-        return …
+        var.update = list(difference)
+        return var
 
 
 class LastProduct(Transform):
@@ -62,10 +61,10 @@ class LastProduct(Transform):
         super().__init__({}, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        …
-        prod = torch._foreach_mul(tensors, …
-        for p, c in zip(…
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+        prod = torch._foreach_mul(tensors, prev)
+        for p, c in zip(prev, tensors): p.set_(c)
         return prod
 
 class LastRatio(Transform):
@@ -75,12 +74,12 @@ class LastRatio(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        …
-        numerator = …
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, …
-        else: ratio = torch._foreach_div(…
-        for p, c in zip(…
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
         return ratio
 
 class LastAbsoluteRatio(Transform):
@@ -90,17 +89,17 @@ class LastAbsoluteRatio(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        …
-        numerator = …
-        eps = …
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        eps = NumberList(s['eps'] for s in settings)
 
         torch._foreach_abs_(tensors)
-        torch._foreach_clamp_min_(…
+        torch._foreach_clamp_min_(prev, eps)
 
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, …
-        else: ratio = torch._foreach_div(…
-        for p, c in zip(…
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
         return ratio
 
 class GradSign(Transform):
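Note: the recurring change in these hunks is a migration from per-parameter state lookups to a batched `apply(self, tensors, params, grads, loss, states, settings)` hook, where `states` is a list of per-tensor state dicts and `settings` a list of per-tensor settings dicts, with `unpack_states` pulling cached buffers out of `states`. Below is a minimal, standalone sketch of that pattern; the function name `last_ratio` and the explicit state initialization are illustrative stand-ins for torchzero's base class and its `unpack_states` helper, not the library's actual API.

```python
import torch

def last_ratio(tensors: list[torch.Tensor],
               states: list[dict],
               settings: list[dict]) -> list[torch.Tensor]:
    # Lazily create per-tensor 'prev' buffers, initialized to ones -- the role
    # unpack_states(states, tensors, 'prev', init=torch.ones_like) plays in the diff.
    prev = []
    for t, state in zip(tensors, states):
        if 'prev' not in state:
            state['prev'] = torch.ones_like(t)
        prev.append(state['prev'])

    # A global option is read from the first per-tensor settings dict,
    # mirroring settings[0]['numerator'] in the diffed code.
    numerator = settings[0].get('numerator', 'cur')

    if numerator == 'cur':
        ratio = torch._foreach_div(tensors, prev)
    else:
        ratio = torch._foreach_div(prev, tensors)

    # Remember the current tensors for the next call.
    for p, c in zip(prev, tensors):
        p.set_(c)
    return ratio

# usage: the first call divides by ones, the second by the stored previous values
states = [{}, {}]
settings = [{'numerator': 'cur'}, {'numerator': 'cur'}]
a = [torch.full((3,), 2.0), torch.full((2,), 4.0)]
print(last_ratio(a, states, settings))
print(last_ratio([t * 3 for t in a], states, settings))
```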
@@ -109,7 +108,7 @@ class GradSign(Transform):
         super().__init__({}, uses_grad=True, target=target)
 
     @torch.no_grad
-    def …
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [t.copysign_(g) for t,g in zip(tensors, grads)]
 
@@ -119,7 +118,7 @@ class UpdateSign(Transform):
         super().__init__({}, uses_grad=True, target=target)
 
     @torch.no_grad
-    def …
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
 
@@ -130,9 +129,9 @@ class GraftToGrad(Transform):
         super().__init__(defaults, uses_grad=True, target=target)
 
     @torch.no_grad
-    def …
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(…
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
 
 class GraftGradToUpdate(Transform):
@@ -142,9 +141,9 @@ class GraftGradToUpdate(Transform):
         super().__init__(defaults, uses_grad=True, target=target)
 
     @torch.no_grad
-    def …
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(…
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
 
 
@@ -155,8 +154,8 @@ class GraftToParams(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(…
+    def apply(self, tensors, params, grads, loss, states, settings):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
 
 class Relative(Transform):
@@ -166,8 +165,8 @@ class Relative(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        mul = TensorList(params).abs().clamp_(…
+    def apply(self, tensors, params, grads, loss, states, settings):
+        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
         torch._foreach_mul_(tensors, mul)
         return tensors
 
@@ -178,94 +177,94 @@ class FillLoss(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, …
-        alpha = self.get_settings('alpha' …
-        loss = …
-        …
-        return …
+    def step(self, var):
+        alpha = self.get_settings(var.params, 'alpha')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+        return var
 
-class MulByLoss(…
+class MulByLoss(Module):
     """multiplies update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True …
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults …
+        super().__init__(defaults)
 
     @torch.no_grad
-    def …
-        alpha, min_value = self.get_settings('alpha', 'min_value' …
-        loss = …
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_mul_(…
-        return …
+        torch._foreach_mul_(var.update, mul)
+        return var
 
-class DivByLoss(…
+class DivByLoss(Module):
     """divides update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True …
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults …
+        super().__init__(defaults)
 
     @torch.no_grad
-    def …
-        alpha, min_value = self.get_settings('alpha', 'min_value' …
-        loss = …
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_div_(…
-        return …
+        torch._foreach_div_(var.update, mul)
+        return var
 
 
 
-def _sequential_step(self: Module, …
-    params = …
+def _sequential_step(self: Module, var: Var, sequential: bool):
+    params = var.params
     steps = self.settings[params[0]]['steps']
 
     if sequential: modules = self.get_children_sequence()
     else: modules = [self.children['module']] * steps
 
-    if …
+    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
 
     # store original params unless this is last module and can update params directly
-    params_before_steps = None if (…
+    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
 
-    # first step - pass …
-    …
-    …
+    # first step - pass var as usual
+    var = modules[0].step(var)
+    new_var = var
 
-    # subsequent steps - update parameters and create new …
+    # subsequent steps - update parameters and create new var
     if len(modules) > 1:
         for m in modules[1:]:
 
             # update params
-            if (not …
-                if …
-                    torch._foreach_mul_(…
+            if (not new_var.skip_update):
+                if new_var.last_module_lrs is not None:
+                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
-                torch._foreach_sub_(params, …
+                torch._foreach_sub_(params, new_var.get_update())
 
-            # create new …
-            …
-            model=…
+            # create new var since we are at a new point, that means grad, update and loss will be None
+            new_var = Var(params=new_var.params, closure=new_var.closure,
+                          model=new_var.model, current_step=new_var.current_step + 1)
 
             # step
-            …
+            new_var = m.step(new_var)
 
     # final parameter update
-    if (not …
-        if …
-            torch._foreach_mul_(…
+    if (not new_var.skip_update):
+        if new_var.last_module_lrs is not None:
+            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
-        torch._foreach_sub_(params, …
+        torch._foreach_sub_(params, new_var.get_update())
 
-    # if last module, update is applied so return new …
+    # if last module, update is applied so return new var
     if params_before_steps is None:
-        …
-        …
-        return …
+        new_var.stop = True
+        new_var.skip_update = True
+        return new_var
 
     # otherwise use parameter difference as update
-    …
+    var.update = list(torch._foreach_sub(params_before_steps, params))
     for p, bef in zip(params, params_before_steps):
         p.set_(bef) # pyright:ignore[reportArgumentType]
-    return …
+    return var
 
 class Multistep(Module):
     def __init__(self, module: Chainable, steps: int):
@@ -274,8 +273,8 @@ class Multistep(Module):
         self.set_child('module', module)
 
     @torch.no_grad
-    def step(self, …
-        return _sequential_step(self, …
+    def step(self, var):
+        return _sequential_step(self, var, sequential=False)
 
 class Sequential(Module):
     def __init__(self, modules: Iterable[Chainable], steps: int):
@@ -284,8 +283,8 @@ class Sequential(Module):
         self.set_children_sequence(modules)
 
     @torch.no_grad
-    def step(self, …
-        return _sequential_step(self, …
+    def step(self, var):
+        return _sequential_step(self, var, sequential=True)
 
 
 class GradientAccumulation(Module):
@@ -297,22 +296,22 @@ class GradientAccumulation(Module):
 
 
     @torch.no_grad
-    def step(self, …
-        accumulator = self.get_state('accumulator' …
-        settings = self.settings[…
+    def step(self, var):
+        accumulator = self.get_state(var.params, 'accumulator')
+        settings = self.settings[var.params[0]]
         n = settings['n']; mean = settings['mean']; stop = settings['stop']
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         # add update to accumulator
-        torch._foreach_add_(accumulator, …
+        torch._foreach_add_(accumulator, var.get_update())
 
         # step with accumulated updates
         if step % n == 0:
             if mean:
                 torch._foreach_div_(accumulator, n)
 
-            …
-            …
+            var.update = [a.clone() for a in accumulator]
+            var = self.children['modules'].step(var)
 
             # zero accumulator
             torch._foreach_zero_(accumulator)
@@ -320,10 +319,10 @@ class GradientAccumulation(Module):
         else:
             # prevent update
             if stop:
-                …
-                …
+                var.stop=True
+                var.skip_update=True
 
-        return …
+        return var
 
 
 class Dropout(Transform):
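Note: judging from the added lines, `GradientAccumulation.step` sums `var.get_update()` into per-parameter buffers and only lets the wrapped module step on every n-th call, optionally averaging and optionally stopping the chain in between. The sketch below reproduces just that accumulation logic in a standalone form, with a plain callback standing in for the child module; the class and method names are illustrative, not torchzero's API.

```python
import torch

class UpdateAccumulator:
    """Illustrative stand-in for the accumulation logic in GradientAccumulation.step:
    sum incoming updates for n calls, then hand the (optionally averaged) total to a
    callback and reset. Returning False corresponds to var.stop / var.skip_update."""

    def __init__(self, params: list[torch.Tensor], n: int, mean: bool = True):
        self.accumulator = [torch.zeros_like(p) for p in params]
        self.n = n
        self.mean = mean
        self.count = 0

    def step(self, update: list[torch.Tensor], apply_fn) -> bool:
        self.count += 1

        # add update to accumulator (same foreach call as in the diff)
        torch._foreach_add_(self.accumulator, update)

        if self.count % self.n != 0:
            return False

        if self.mean:
            torch._foreach_div_(self.accumulator, self.n)

        # pass a copy of the accumulated update on, then zero the buffers
        apply_fn([a.clone() for a in self.accumulator])
        torch._foreach_zero_(self.accumulator)
        return True
```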
@@ -332,10 +331,10 @@ class Dropout(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-        p = …
-        graft = …
+        p = NumberList(s['p'] for s in settings)
+        graft = settings[0]['graft']
 
         if graft:
             target_norm = tensors.global_vector_norm()
@@ -351,11 +350,11 @@ class WeightDropout(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, …
-        closure = …
+    def step(self, var):
+        closure = var.closure
         if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(…
-        p = self.…
+        params = TensorList(var.params)
+        p = NumberList(self.settings[p]['p'] for p in params)
         mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
 
         @torch.no_grad
@@ -369,8 +368,8 @@ class WeightDropout(Module):
             params.copy_(orig_params)
             return loss
 
-        …
-        return …
+        var.closure = dropout_closure
+        return var
 
 class NoiseSign(Transform):
     """uses random vector with update sign"""
@@ -379,8 +378,8 @@ class NoiseSign(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def …
-        alpha = …
+    def apply(self, tensors, params, grads, loss, states, settings):
+        alpha = [s['alpha'] for s in settings]
         distribution = self.settings[params[0]]['distribution']
         return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
 
@@ -391,29 +390,29 @@ class NegateOnLossIncrease(Module):
         super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def step(self, …
-        closure = …
+    def step(self, var):
+        closure = var.closure
         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-        backtrack = self.settings[…
+        backtrack = self.settings[var.params[0]]['backtrack']
 
-        update = …
-        f_0 = …
+        update = var.get_update()
+        f_0 = var.get_loss(backward=False)
 
-        torch._foreach_sub_(…
+        torch._foreach_sub_(var.params, update)
         f_1 = closure(False)
 
         if f_1 <= f_0:
-            if …
-                …
-                …
-                return …
+            if var.is_last and var.last_module_lrs is None:
+                var.stop = True
+                var.skip_update = True
+                return var
 
-            torch._foreach_add_(…
-            return …
+            torch._foreach_add_(var.params, update)
+            return var
 
-        torch._foreach_add_(…
+        torch._foreach_add_(var.params, update)
         if backtrack:
-            torch._foreach_neg_(…
+            torch._foreach_neg_(var.update)
         else:
-            torch._foreach_zero_(…
-        return …
+            torch._foreach_zero_(var.update)
+        return var
torchzero/modules/ops/multi.py
CHANGED
@@ -7,7 +7,7 @@ from typing import Any
 
 import torch
 
-from ...core import Chainable, Module, Target, …
+from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist
 
 
@@ -29,25 +29,25 @@ class MultiOperation(Module, ABC):
             raise ValueError('At least one operand must be a module')
 
     @abstractmethod
-    def transform(self, …
+    def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
     @torch.no_grad
-    def step(self, …
+    def step(self, var: Var) -> Var:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-                …
-                processed_operands[k] = …
-                …
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[k] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(…
-        …
-        return …
+        transformed = self.transform(var, **processed_operands)
+        var.update = transformed
+        return var
 
 
 
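Note: the `MultiOperation.step` hunk above routes a clone of the incoming update through every module operand (`var.clone(clone_update=True)` → `step` → `get_update()`), leaves non-module operands untouched, and then combines the results with `transform(var, **operands)`. Below is a standalone sketch of that flow, with plain callables standing in for child modules and a list of tensors standing in for `var.update`; all names here are illustrative, not torchzero's API.

```python
import torch
from typing import Any, Callable

def combine_operands(update: list[torch.Tensor],
                     operands: dict[str, Any],
                     transform: Callable[..., list[torch.Tensor]]) -> list[torch.Tensor]:
    """Module operands (callables here) each receive a clone of the incoming update
    and return their own update; constants pass through unchanged. The results are
    then combined by `transform`, mirroring MultiOperation.step in the diff."""
    processed: dict[str, Any] = {}
    for name, op in operands.items():
        if callable(op):
            # mirrors v.step(var.clone(clone_update=True)) followed by get_update()
            processed[name] = op([u.clone() for u in update])
        else:
            processed[name] = op
    return transform(**processed)

# tiny usage: clip the incoming update between two constant bounds,
# roughly what a ClipModules-style combination does
update = [torch.randn(4)]
result = combine_operands(
    update,
    operands={'input': lambda u: u, 'min': -0.5, 'max': 0.5},
    transform=lambda input, min, max: [t.clamp(min=min, max=max) for t in input],
)
```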
@@ -57,8 +57,8 @@ class SubModules(MultiOperation):
         super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
-    def transform(self, …
-        alpha = self.settings[…
+    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        alpha = self.settings[var.params[0]]['alpha']
 
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
@@ -74,7 +74,7 @@ class DivModules(MultiOperation):
         super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
             return input / TensorList(other)
@@ -88,7 +88,7 @@ class PowModules(MultiOperation):
         super().__init__(defaults, input=input, exponent=exponent)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(exponent, list)
             return input ** TensorList(exponent)
@@ -102,8 +102,8 @@ class LerpModules(MultiOperation):
         super().__init__(defaults, input=input, end=end)
 
     @torch.no_grad
-    def transform(self, …
-        torch._foreach_lerp_(input, end, weight=self.settings[…
+    def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
+        torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
         return input
 
 class ClipModules(MultiOperation):
@@ -112,7 +112,7 @@ class ClipModules(MultiOperation):
         super().__init__(defaults, input=input, min=min, max=max)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
         return TensorList(input).clamp_(min=min, max=max)
 
 
@@ -122,8 +122,8 @@ class GraftModules(MultiOperation):
         super().__init__(defaults, direction=direction, magnitude=magnitude)
 
     @torch.no_grad
-    def transform(self, …
-        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[…
+    def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
+        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
 
 
@@ -132,6 +132,6 @@ class Where(MultiOperation):
         super().__init__({}, condition=condition, input=input, other=other)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
         return tensorlist.where(TensorList(condition).as_bool(), input, other)
 