torchzero 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- torchzero/core/__init__.py +1 -1
- torchzero/core/module.py +72 -49
- torchzero/core/tensorlist_optimizer.py +1 -1
- torchzero/modules/adaptive/adaptive.py +11 -11
- torchzero/modules/experimental/experimental.py +41 -41
- torchzero/modules/experimental/quad_interp.py +8 -8
- torchzero/modules/experimental/subspace.py +37 -37
- torchzero/modules/gradient_approximation/base_approximator.py +19 -24
- torchzero/modules/gradient_approximation/fdm.py +1 -1
- torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
- torchzero/modules/gradient_approximation/rfdm.py +1 -1
- torchzero/modules/line_search/armijo.py +8 -8
- torchzero/modules/line_search/base_ls.py +8 -8
- torchzero/modules/line_search/directional_newton.py +14 -14
- torchzero/modules/line_search/grid_ls.py +7 -7
- torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
- torchzero/modules/meta/alternate.py +4 -4
- torchzero/modules/meta/grafting.py +23 -23
- torchzero/modules/meta/optimizer_wrapper.py +14 -14
- torchzero/modules/meta/return_overrides.py +8 -8
- torchzero/modules/misc/accumulate.py +6 -6
- torchzero/modules/misc/basic.py +16 -16
- torchzero/modules/misc/lr.py +2 -2
- torchzero/modules/misc/multistep.py +7 -7
- torchzero/modules/misc/on_increase.py +9 -9
- torchzero/modules/momentum/momentum.py +4 -4
- torchzero/modules/operations/multi.py +44 -44
- torchzero/modules/operations/reduction.py +28 -28
- torchzero/modules/operations/singular.py +9 -9
- torchzero/modules/optimizers/adagrad.py +1 -1
- torchzero/modules/optimizers/adam.py +8 -8
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +1 -1
- torchzero/modules/optimizers/rprop.py +1 -1
- torchzero/modules/optimizers/sgd.py +2 -2
- torchzero/modules/orthogonalization/newtonschulz.py +3 -3
- torchzero/modules/orthogonalization/svd.py +1 -1
- torchzero/modules/regularization/dropout.py +1 -1
- torchzero/modules/regularization/noise.py +3 -3
- torchzero/modules/regularization/normalization.py +5 -5
- torchzero/modules/regularization/ortho_grad.py +1 -1
- torchzero/modules/regularization/weight_decay.py +1 -1
- torchzero/modules/scheduling/lr_schedulers.py +2 -2
- torchzero/modules/scheduling/step_size.py +8 -8
- torchzero/modules/second_order/newton.py +12 -12
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
- torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
- torchzero/modules/weight_averaging/ema.py +3 -3
- torchzero/modules/weight_averaging/swa.py +8 -8
- torchzero/optim/first_order/forward_gradient.py +1 -1
- torchzero/optim/modular.py +4 -4
- torchzero/tensorlist.py +8 -1
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/METADATA +1 -1
- torchzero-0.1.5.dist-info/RECORD +104 -0
- torchzero-0.1.3.dist-info/RECORD +0 -104
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/LICENSE +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/WHEEL +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/top_level.txt +0 -0
torchzero/modules/operations/multi.py

@@ -16,11 +16,11 @@ class Add(OptimizerModule):
         self.value = value

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.value, (int, float)):
             return ascent.add_(self.value)

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         v = self.children['value'].return_ascent(state_copy)
         return ascent.add_(v)

@@ -36,11 +36,11 @@ class Sub(OptimizerModule):
         self.subtrahend = subtrahend

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.subtrahend, (int, float)):
             return ascent.sub_(self.subtrahend)

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         subtrahend = self.children['subtrahend'].return_ascent(state_copy)
         return ascent.sub_(subtrahend)

@@ -55,11 +55,11 @@ class RSub(OptimizerModule):
         self.minuend = minuend

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.minuend, (int, float)):
             return ascent.sub_(self.minuend).neg_()

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         minuend = self.children['minuend'].return_ascent(state_copy)
         return ascent.sub_(minuend).neg_()

@@ -75,14 +75,14 @@ class Subtract(OptimizerModule):
         self._set_child_('subtrahend', subtrahend)

     @torch.no_grad
-    def step(self,
-        state_copy =
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent = True)
         minuend = self.children['minuend'].return_ascent(state_copy)
-
-        subtrahend = self.children['subtrahend'].return_ascent(
+        vars.update_attrs_(state_copy)
+        subtrahend = self.children['subtrahend'].return_ascent(vars)

-
-        return self._update_params_or_step_with_next(
+        vars.ascent = minuend.sub_(subtrahend)
+        return self._update_params_or_step_with_next(vars)

 class Mul(OptimizerModule):
     """multiplies update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""

@@ -95,11 +95,11 @@ class Mul(OptimizerModule):
         self.value = value

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.value, (int, float)):
             return ascent.mul_(self.value)

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         v = self.children['value'].return_ascent(state_copy)
         return ascent.mul_(v)

@@ -115,11 +115,11 @@ class Div(OptimizerModule):
         self.denominator = denominator

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.denominator, (int, float)):
             return ascent.div_(self.denominator)

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         denominator = self.children['denominator'].return_ascent(state_copy)
         return ascent.div_(denominator)

@@ -134,11 +134,11 @@ class RDiv(OptimizerModule):
         self.numerator = numerator

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.numerator, (int, float)):
             return ascent.reciprocal_().mul_(self.numerator)

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         numerator = self.children['numerator'].return_ascent(state_copy)
         return ascent.reciprocal_().mul_(numerator)

@@ -154,14 +154,14 @@ class Divide(OptimizerModule):
         self._set_child_('denominator', denominator)

     @torch.no_grad
-    def step(self,
-        state_copy =
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent = True)
         numerator = self.children['numerator'].return_ascent(state_copy)
-
-        denominator = self.children['denominator'].return_ascent(
+        vars.update_attrs_(state_copy)
+        denominator = self.children['denominator'].return_ascent(vars)

-
-        return self._update_params_or_step_with_next(
+        vars.ascent = numerator.div_(denominator)
+        return self._update_params_or_step_with_next(vars)


 class Pow(OptimizerModule):

@@ -175,11 +175,11 @@ class Pow(OptimizerModule):
         self.power = power

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.power, (int, float)):
             return ascent.pow_(self.power)

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         power = self.children['power'].return_ascent(state_copy)
         return ascent.pow_(power)

@@ -194,11 +194,11 @@ class RPow(OptimizerModule):
         self.base = base

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.base, (int, float)):
             return self.base ** ascent

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         base = self.children['base'].return_ascent(state_copy)
         return base.pow_(ascent)

@@ -214,14 +214,14 @@ class Power(OptimizerModule):
         self._set_child_('power', power)

     @torch.no_grad
-    def step(self,
-        state_copy =
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent = True)
         base = self.children['base'].return_ascent(state_copy)
-
-        power = self.children['power'].return_ascent(
+        vars.update_attrs_(state_copy)
+        power = self.children['power'].return_ascent(vars)

-
-        return self._update_params_or_step_with_next(
+        vars.ascent = base.pow_(power)
+        return self._update_params_or_step_with_next(vars)


 class Lerp(OptimizerModule):

@@ -235,9 +235,9 @@ class Lerp(OptimizerModule):
         self.weight = weight

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         end = self.children['end'].return_ascent(state_copy)
         return ascent.lerp_(end, self.weight)

@@ -259,15 +259,15 @@ class Interpolate(OptimizerModule):
         self.weight = weight

     @torch.no_grad
-    def step(self,
-        state_copy =
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent = True)
         input = self.children['input'].return_ascent(state_copy)
-
-        end = self.children['end'].return_ascent(
+        vars.update_attrs_(state_copy)
+        end = self.children['end'].return_ascent(vars)

-
+        vars.ascent = input.lerp_(end, weight = self.weight)

-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars)

 class AddMagnitude(OptimizerModule):
     """Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update.

@@ -288,11 +288,11 @@ class AddMagnitude(OptimizerModule):
         self.add_to_zero = add_to_zero

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if isinstance(self.value, (int, float)):
             if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
             return ascent.add_(ascent.sign_().mul_(self.value))

-        state_copy =
+        state_copy = vars.copy(clone_ascent = True)
         v = self.children['value'].return_ascent(state_copy)
         return ascent.add_(v.abs_().mul_(ascent.sign()))
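These hunks all apply the same mechanical change: `_update` and `step` now receive a `vars` object, scalar operands are still applied to `ascent` in place, and module operands are evaluated through `vars.copy(clone_ascent = True)` plus `return_ascent`, with the combined result written back to `vars.ascent` before calling `_update_params_or_step_with_next(vars)`. The following is a minimal sketch of a custom module written against this 0.1.5-style interface; the import path and the exact semantics of `vars` are assumptions based only on the added lines above, and `WeightedSum` itself is hypothetical, not part of torchzero.

```python
# Hypothetical example that follows the pattern of the added lines above.
# Assumes `OptimizerModule` is importable as shown and that `vars` exposes
# copy(clone_ascent=...), update_attrs_(), ascent, as the diff suggests.
import torch
from torchzero.core.module import OptimizerModule  # assumed import path


class WeightedSum(OptimizerModule):
    """Hypothetical module: ascent = a * first_child + b * second_child."""

    def __init__(self, first, second, a=1.0, b=1.0):
        super().__init__({})
        self.a, self.b = a, b
        self._set_child_('first', first)
        self._set_child_('second', second)

    @torch.no_grad()
    def step(self, vars):
        # evaluate the first child on a copy so the shared vars are not consumed
        state_copy = vars.copy(clone_ascent=True)
        first = self.children['first'].return_ascent(state_copy)

        # merge whatever the first child computed (loss, grads) back into vars,
        # then evaluate the second child on vars itself
        vars.update_attrs_(state_copy)
        second = self.children['second'].return_ascent(vars)

        # write the combined ascent back and let the next module (or the
        # parameter update) consume it
        vars.ascent = first.mul_(self.a).add_(second.mul_(self.b))
        return self._update_params_or_step_with_next(vars)
```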
torchzero/modules/operations/reduction.py

@@ -26,26 +26,26 @@ class Sum(OptimizerModule):
             self._set_child_(i, module)

     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         if len(self.children) == 1:
-
-            if self.scalar is not None:
-            return self._update_params_or_step_with_next(
+            vars.ascent = self.children[0].return_ascent(vars)
+            if self.scalar is not None: vars.ascent += self.scalar
+            return self._update_params_or_step_with_next(vars)

         sum = None
         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
-            if i == len(self.children) - 1: cur_state =
-            else: cur_state =
+            if i == len(self.children) - 1: cur_state = vars
+            else: cur_state = vars.copy(clone_ascent = True)

             if sum is None: sum = c.return_ascent(cur_state)
             else: sum += c.return_ascent(cur_state)

-            if i != len(self.children) - 1:
+            if i != len(self.children) - 1: vars.update_attrs_(cur_state)

         assert sum is not None
         if self.scalar is not None: sum += self.scalar
-
-        return self._update_params_or_step_with_next(
+        vars.ascent = sum
+        return self._update_params_or_step_with_next(vars)

 class Mean(OptimizerModule):
     """calculates mean of multiple updates.

@@ -69,28 +69,28 @@ class Mean(OptimizerModule):
             self._set_child_(i, module)

     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         if len(self.children) == 1:
-
-            if self.scalar is not None:
-            if self.n_values > 1:
-            return self._update_params_or_step_with_next(
+            vars.ascent = self.children[0].return_ascent(vars)
+            if self.scalar is not None: vars.ascent += self.scalar
+            if self.n_values > 1: vars.ascent /= self.n_values
+            return self._update_params_or_step_with_next(vars)

         sum = None
         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
-            if i == len(self.children) - 1: cur_state =
-            else: cur_state =
+            if i == len(self.children) - 1: cur_state = vars
+            else: cur_state = vars.copy(clone_ascent = True)

             if sum is None: sum = c.return_ascent(cur_state)
             else: sum += c.return_ascent(cur_state)

-            if i != len(self.children) - 1:
+            if i != len(self.children) - 1: vars.update_attrs_(cur_state)

         assert sum is not None
         if self.scalar is not None: sum += self.scalar
         if self.n_values > 1: sum /= self.n_values
-
-        return self._update_params_or_step_with_next(
+        vars.ascent = sum
+        return self._update_params_or_step_with_next(vars)

 class Product(OptimizerModule):
     """calculates product of multiple updates.

@@ -112,23 +112,23 @@ class Product(OptimizerModule):
             self._set_child_(i, module)

     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         if len(self.children) == 1:
-
-            if self.scalar is not None:
-            return self._update_params_or_step_with_next(
+            vars.ascent = self.children[0].return_ascent(vars)
+            if self.scalar is not None: vars.ascent *= self.scalar
+            return self._update_params_or_step_with_next(vars)

         prod = None
         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
-            if i == len(self.children) - 1: cur_state =
-            else: cur_state =
+            if i == len(self.children) - 1: cur_state = vars
+            else: cur_state = vars.copy(clone_ascent = True)

             if prod is None: prod = c.return_ascent(cur_state)
             else: prod *= c.return_ascent(cur_state)

-            if i != len(self.children) - 1:
+            if i != len(self.children) - 1: vars.update_attrs_(cur_state)

         assert prod is not None
         if self.scalar is not None: prod *= self.scalar
-
-        return self._update_params_or_step_with_next(
+        vars.ascent = prod
+        return self._update_params_or_step_with_next(vars)
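The Sum, Mean, and Product hunks share one aggregation loop: only the last child is stepped on `vars` itself, every earlier child gets a `vars.copy(clone_ascent = True)`, and whatever those copies compute is folded back through `vars.update_attrs_(cur_state)`. Below is a hedged sketch of that loop in isolation; `aggregate_ascents` and its `combine` argument are hypothetical helpers, and the `vars`/children behavior is assumed only from the added lines above.

```python
# Sketch of the shared aggregation loop used by Sum/Mean/Product in 0.1.5.
# `module.children` is assumed to map integer keys to child OptimizerModules,
# exactly as the added lines above imply; nothing here is a public API guarantee.
def aggregate_ascents(module, vars, combine):
    total = None
    last = len(module.children) - 1
    for i, child in sorted(module.children.items(), key=lambda x: x[0]):
        # the last child consumes `vars` directly; earlier children work on copies
        cur_state = vars if i == last else vars.copy(clone_ascent=True)

        asc = child.return_ascent(cur_state)
        total = asc if total is None else combine(total, asc)

        # fold loss/gradient info computed on a copy back into the shared vars
        if i != last:
            vars.update_attrs_(cur_state)

    vars.ascent = total
    return module._update_params_or_step_with_next(vars)
```

With `combine=lambda a, b: a + b` this mirrors the Sum path; Mean additionally divides by `n_values`, and Product multiplies instead of adding.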
torchzero/modules/operations/singular.py

@@ -18,7 +18,7 @@ class Operation(OptimizerModule):
         self.operation = methodcaller(f'{operation}_')

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent): return self.operation(ascent)

 class Reciprocal(OptimizerModule):
     """*1 / update*"""

@@ -26,7 +26,7 @@ class Reciprocal(OptimizerModule):
         super().__init__({})

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.reciprocal_()

 class Negate(OptimizerModule):
     """minus update"""

@@ -34,7 +34,7 @@ class Negate(OptimizerModule):
         super().__init__({})

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.neg_()


 def sign_grad_(params: Iterable[torch.Tensor]):

@@ -51,7 +51,7 @@ class Sign(OptimizerModule):
         super().__init__({})

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.sign_()

 class Abs(OptimizerModule):
     """takes absolute values of the update."""

@@ -59,7 +59,7 @@ class Abs(OptimizerModule):
         super().__init__({})

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.abs_()

 class Sin(OptimizerModule):
     """applies sin function to the ascent"""

@@ -67,7 +67,7 @@ class Sin(OptimizerModule):
         super().__init__({})

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.sin_()

 class Cos(OptimizerModule):
     """applies cos function to the ascent"""

@@ -75,7 +75,7 @@ class Cos(OptimizerModule):
         super().__init__({})

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.cos_()


 class NanToNum(OptimizerModule):

@@ -97,7 +97,7 @@ class NanToNum(OptimizerModule):
         self.neginf = neginf

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.nan_to_num_(self.nan, self.posinf, self.neginf)


 class MagnitudePower(OptimizerModule):

@@ -107,7 +107,7 @@ class MagnitudePower(OptimizerModule):
         self.value = value

     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent):
         if self.value % 2 == 1: return ascent.pow_(self.value)
         return ascent.abs().pow_(self.value) * ascent.sign()
torchzero/modules/optimizers/adagrad.py

@@ -31,7 +31,7 @@ class Adagrad(OptimizerModule):
         self.cur_step = 0

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         settings = self.get_all_group_keys()
         if self.cur_step == 0: init = ascent.full_like(settings['initial_accumulator_value'])
         else: init = None

torchzero/modules/optimizers/adam.py

@@ -48,7 +48,7 @@ class Adam(OptimizerModule):
         self.amsgrad = amsgrad

     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         # Adam step is a bit differet from other optimizer steps
         # due to how common it is, I implemented two additional optimizations,

@@ -85,14 +85,14 @@ class Adam(OptimizerModule):
         alpha = settings['alpha']

         # get params if ascent is None so we need params to access their gradient as initial ascent
-        if
+        if vars.ascent is None:
             if params is None: pg = self.get_params()
             else: pg = params
         else:
             pg = None

         ret = _adam_step(
-            ascent=
+            ascent=vars.maybe_use_grad_(pg),
             exp_avg = exp_avg,
             exp_avg_sq = exp_avg_sq,
             alpha = alpha,

@@ -107,12 +107,12 @@ class Adam(OptimizerModule):
         self.cur_step += 1
         if params is None:
             assert ret is not None
-
-            return self._update_params_or_step_with_next(
+            vars.ascent = ret
+            return self._update_params_or_step_with_next(vars)

         # next module is either None or LR
-        if self.next_module is None: return
+        if self.next_module is None: return vars.get_loss()

         # step with LR, which has _skip = True so it won't apply lr, but may step with the scheduler
-        self.next_module._update(
-        return
+        self.next_module._update(vars, None) # type:ignore
+        return vars.get_loss()

torchzero/modules/optimizers/lion.py

@@ -22,7 +22,7 @@ class Lion(OptimizerModule):
         super().__init__(defaults)

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         beta1, beta2 = self.get_group_keys('beta1', 'beta2')
         ema = self.get_state_key('ema')
         return _lion_step_(ascent,ema,beta1,beta2)

torchzero/modules/optimizers/rmsprop.py

@@ -40,7 +40,7 @@ class RMSProp(OptimizerModule):
         self.centered = centered

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         settings = self.get_all_group_keys()
         if self.centered:
             mean, mean_sqr = self.get_state_keys('mean', 'mean_sqr')

torchzero/modules/optimizers/sgd.py

@@ -15,7 +15,7 @@ class SGD(OptimizerModule):
         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
         nesterov (bool, optional):
             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        alpha (float, optional): learning rate. Defaults to
+        alpha (float, optional): learning rate. Defaults to 1.
     """
     def __init__(
         self,

@@ -32,7 +32,7 @@ class SGD(OptimizerModule):
         self.current_step = 0

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
         settings = self.get_all_group_keys()

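For stateful optimizer modules the change is the same signature swap: per-parameter buffers still come from `get_state_key` / `get_group_key`, and `_update` simply returns the transformed ascent (the Lion and RMSProp hunks above follow this shape), while `step`-level modules like Adam additionally use `vars.maybe_use_grad_` and `vars.get_loss()`. Below is a hedged sketch of a heavy-ball style module under the 0.1.5 signature; `HeavyBall` is illustrative and not part of the package, and the zero-initialization of `get_state_key` plus TensorList `clone()` are assumptions inferred from the hunks, not documented behavior.

```python
# Hypothetical stateful module in the 0.1.5 `_update(self, vars, ascent)` style.
# `get_state_key` / `get_group_key` usage mirrors the Lion and RMSProp hunks above.
import torch
from torchzero.core.module import OptimizerModule  # assumed import path


class HeavyBall(OptimizerModule):
    """Hypothetical heavy-ball momentum: buf = momentum * buf + ascent."""

    def __init__(self, momentum: float = 0.9):
        super().__init__(dict(momentum=momentum))

    @torch.no_grad()
    def _update(self, vars, ascent):
        momentum = self.get_group_key('momentum')
        # per-parameter state, assumed to start at zeros as the Lion hunk suggests
        buf = self.get_state_key('momentum_buffer')
        buf.mul_(momentum).add_(ascent)
        # return a clone so later modules don't mutate the stored buffer
        return buf.clone()
```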
torchzero/modules/orthogonalization/newtonschulz.py

@@ -116,7 +116,7 @@ class ZeropowerViaNewtonSchulz(OptimizerModule):
         if compiled: self._zeropower_via_newtonschulz5 = _compiled_zeropower_via_newtonschulz5
         else: self._zeropower_via_newtonschulz5 = _zeropower_via_newtonschulz5

-    def _update(self,
+    def _update(self, vars, ascent):
         toggle, ns_steps, adaptive = self.get_group_keys('newtonshultz', 'ns_steps', 'adaptive', cls=list)

         for asc, enable, steps, ada in zip(ascent, toggle, ns_steps, adaptive):

@@ -146,11 +146,11 @@ class DualNormCorrection(OptimizerModule):
         defaults = dict(adaptive_scale_min = adaptive_scale_min, adaptive_scale_max = adaptive_scale_max)
         super().__init__(defaults)

-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
         adaptive_scale_min, adaptive_scale_max = self.get_group_keys('adaptive_scale_min', 'adaptive_scale_max')

-        for asc, grad, min, max in zip(ascent,
+        for asc, grad, min, max in zip(ascent, vars.maybe_compute_grad_(params), adaptive_scale_min, adaptive_scale_max):
             if len([i for i in asc.shape if i > 1]) != 0:
                 scale = torch.einsum('ij,ij->', grad.view(grad.shape[0], -1), asc.view(asc.shape[0], -1))
                 if min is not None or max is not None: scale = scale.clip(min, max)

torchzero/modules/orthogonalization/svd.py

@@ -80,7 +80,7 @@ class Orthogonalize(OptimizerModule):
         super().__init__(defaults, target = target)
         self.warn_fail = warn_fail

-    def _update(self,
+    def _update(self, vars, ascent):
         toggle = self.get_group_key('orth', cls=list)
         _orthogonalize_update_(ascent, toggle, self.warn_fail)
         return ascent
torchzero/modules/regularization/noise.py

@@ -18,7 +18,7 @@ def add_noise_(
         grads += grads.sample_like(alpha, distribution)

     elif mode == 'global':
-        grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution)
+        grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution) # type:ignore

     elif mode == 'param':
         grads += grads.sample_like(grads.abs().mean()*alpha, distribution)

@@ -57,7 +57,7 @@ class AddNoise(OptimizerModule):
         self.mode: Literal["absolute", "global", "param", "channel"] = mode

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         alpha = self.get_group_key('alpha')

         add_noise_(ascent, alpha, self.distribution, self.mode)

@@ -72,6 +72,6 @@ class Random(OptimizerModule):
         self.distribution: Distributions = distribution

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         alpha = self.get_group_key('alpha')
         return ascent.sample_like(alpha, self.distribution)
torchzero/modules/regularization/normalization.py

@@ -29,7 +29,7 @@ def _normalize_grad_(
     if not isinstance(grads, TensorList): grads = TensorList(grads)
     norm = grads.total_vector_norm(ord)
     if norm > min:
-        grads /= norm / norm_value
+        grads /= norm / norm_value # type:ignore

 @torch.no_grad
 def normalize_grad_(

@@ -112,7 +112,7 @@ class Normalize(OptimizerModule):
         self.min_numel = min_numel

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         _normalize_grad_(
             ascent,
             norm_value = self.norm_value,

@@ -225,7 +225,7 @@ class Centralize(OptimizerModule):
         self.min_numel = min_numel

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         _centralize_grad_(
             ascent,
             mode = self.mode,

@@ -258,7 +258,7 @@ class ClipValue(OptimizerModule):
         super().__init__(defaults)

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         value = self.get_group_key('value')
         ascent.clamp_(-value, value)
         return ascent

@@ -317,7 +317,7 @@ class ClipNorm(OptimizerModule):
         self.mode: typing.Literal["global", "param", "channel"] = mode

     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         _normalize_grad_(
             ascent,
             norm_value = self.max_norm,

torchzero/modules/regularization/ortho_grad.py

@@ -58,7 +58,7 @@ class OrthoGrad(OptimizerModule):
         self.renormalize = renormalize
         self.sqrt_scale = sqrt_scale

-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()

         if self.renormalize: orig_norm = ascent.norm(2) + self.eps
torchzero/modules/scheduling/lr_schedulers.py

@@ -81,7 +81,7 @@ if TYPE_CHECKING:

     # self.id = random.random()

-    # def step(self,
+    # def step(self, vars):
     #     if self.cur % self.update_every == 0:
     #         self.scheduler_step_fn()
     #         self.cur_lr = self.dummy_opt.first_param_group['lr']

@@ -113,7 +113,7 @@ class LRWarmup(OptimizerModule):

         self.cur = 0

-    def _update(self,
+    def _update(self, vars, ascent):
         if self.cur < self.delay_steps:
             if self.start_lr != 1: ascent *= self.start_lr
