torchzero 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (59)
  1. torchzero/core/__init__.py +1 -1
  2. torchzero/core/module.py +72 -49
  3. torchzero/core/tensorlist_optimizer.py +1 -1
  4. torchzero/modules/adaptive/adaptive.py +11 -11
  5. torchzero/modules/experimental/experimental.py +41 -41
  6. torchzero/modules/experimental/quad_interp.py +8 -8
  7. torchzero/modules/experimental/subspace.py +37 -37
  8. torchzero/modules/gradient_approximation/base_approximator.py +19 -24
  9. torchzero/modules/gradient_approximation/fdm.py +1 -1
  10. torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
  11. torchzero/modules/gradient_approximation/rfdm.py +1 -1
  12. torchzero/modules/line_search/armijo.py +8 -8
  13. torchzero/modules/line_search/base_ls.py +8 -8
  14. torchzero/modules/line_search/directional_newton.py +14 -14
  15. torchzero/modules/line_search/grid_ls.py +7 -7
  16. torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
  17. torchzero/modules/meta/alternate.py +4 -4
  18. torchzero/modules/meta/grafting.py +23 -23
  19. torchzero/modules/meta/optimizer_wrapper.py +14 -14
  20. torchzero/modules/meta/return_overrides.py +8 -8
  21. torchzero/modules/misc/accumulate.py +6 -6
  22. torchzero/modules/misc/basic.py +16 -16
  23. torchzero/modules/misc/lr.py +2 -2
  24. torchzero/modules/misc/multistep.py +7 -7
  25. torchzero/modules/misc/on_increase.py +9 -9
  26. torchzero/modules/momentum/momentum.py +4 -4
  27. torchzero/modules/operations/multi.py +44 -44
  28. torchzero/modules/operations/reduction.py +28 -28
  29. torchzero/modules/operations/singular.py +9 -9
  30. torchzero/modules/optimizers/adagrad.py +1 -1
  31. torchzero/modules/optimizers/adam.py +8 -8
  32. torchzero/modules/optimizers/lion.py +1 -1
  33. torchzero/modules/optimizers/rmsprop.py +1 -1
  34. torchzero/modules/optimizers/rprop.py +1 -1
  35. torchzero/modules/optimizers/sgd.py +2 -2
  36. torchzero/modules/orthogonalization/newtonschulz.py +3 -3
  37. torchzero/modules/orthogonalization/svd.py +1 -1
  38. torchzero/modules/regularization/dropout.py +1 -1
  39. torchzero/modules/regularization/noise.py +3 -3
  40. torchzero/modules/regularization/normalization.py +5 -5
  41. torchzero/modules/regularization/ortho_grad.py +1 -1
  42. torchzero/modules/regularization/weight_decay.py +1 -1
  43. torchzero/modules/scheduling/lr_schedulers.py +2 -2
  44. torchzero/modules/scheduling/step_size.py +8 -8
  45. torchzero/modules/second_order/newton.py +12 -12
  46. torchzero/modules/smoothing/__init__.py +1 -1
  47. torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
  48. torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
  49. torchzero/modules/weight_averaging/ema.py +3 -3
  50. torchzero/modules/weight_averaging/swa.py +8 -8
  51. torchzero/optim/first_order/forward_gradient.py +1 -1
  52. torchzero/optim/modular.py +4 -4
  53. torchzero/tensorlist.py +8 -1
  54. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/METADATA +1 -1
  55. torchzero-0.1.5.dist-info/RECORD +104 -0
  56. torchzero-0.1.3.dist-info/RECORD +0 -104
  57. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/LICENSE +0 -0
  58. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/WHEEL +0 -0
  59. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/top_level.txt +0 -0
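In the hunks shown below, the change is almost entirely a mechanical rename: the first parameter of the `OptimizerModule.step` and `OptimizerModule._update` hooks changes from `state` to `vars` between 0.1.3 and 0.1.5 (the other visible edits are the SGD docstring now stating that `alpha` defaults to 1, and a couple of added `# type:ignore` comments). A minimal sketch of what the rename looks like for a downstream subclass follows; this assumes `OptimizerModule` is importable from `torchzero.core.module` as the file list suggests, and the class itself is hypothetical, not part of the package:

import torch
from torchzero.core.module import OptimizerModule  # import path assumed from the file list above

class ScaleByTwo(OptimizerModule):
    """Hypothetical downstream module that doubles the ascent direction."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad()
    def _update(self, vars, ascent):  # was `_update(self, state, ascent)` in 0.1.3
        # the argument is only renamed, not repurposed, so positional overrides keep
        # working; only code that passed or read it via the keyword `state` needs updating
        return ascent.mul_(2.0)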
@@ -16,11 +16,11 @@ class Add(OptimizerModule):
  self.value = value

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.value, (int, float)):
  return ascent.add_(self.value)

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  v = self.children['value'].return_ascent(state_copy)
  return ascent.add_(v)

@@ -36,11 +36,11 @@ class Sub(OptimizerModule):
  self.subtrahend = subtrahend

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.subtrahend, (int, float)):
  return ascent.sub_(self.subtrahend)

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  subtrahend = self.children['subtrahend'].return_ascent(state_copy)
  return ascent.sub_(subtrahend)

@@ -55,11 +55,11 @@ class RSub(OptimizerModule):
  self.minuend = minuend

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.minuend, (int, float)):
  return ascent.sub_(self.minuend).neg_()

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  minuend = self.children['minuend'].return_ascent(state_copy)
  return ascent.sub_(minuend).neg_()

@@ -75,14 +75,14 @@ class Subtract(OptimizerModule):
  self._set_child_('subtrahend', subtrahend)

  @torch.no_grad
- def step(self, state):
- state_copy = state.copy(clone_ascent = True)
+ def step(self, vars):
+ state_copy = vars.copy(clone_ascent = True)
  minuend = self.children['minuend'].return_ascent(state_copy)
- state.update_attrs_(state_copy)
- subtrahend = self.children['subtrahend'].return_ascent(state)
+ vars.update_attrs_(state_copy)
+ subtrahend = self.children['subtrahend'].return_ascent(vars)

- state.ascent = minuend.sub_(subtrahend)
- return self._update_params_or_step_with_next(state)
+ vars.ascent = minuend.sub_(subtrahend)
+ return self._update_params_or_step_with_next(vars)

  class Mul(OptimizerModule):
  """multiplies update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
@@ -95,11 +95,11 @@ class Mul(OptimizerModule):
  self.value = value

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.value, (int, float)):
  return ascent.mul_(self.value)

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  v = self.children['value'].return_ascent(state_copy)
  return ascent.mul_(v)

@@ -115,11 +115,11 @@ class Div(OptimizerModule):
  self.denominator = denominator

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.denominator, (int, float)):
  return ascent.div_(self.denominator)

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  denominator = self.children['denominator'].return_ascent(state_copy)
  return ascent.div_(denominator)

@@ -134,11 +134,11 @@ class RDiv(OptimizerModule):
  self.numerator = numerator

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.numerator, (int, float)):
  return ascent.reciprocal_().mul_(self.numerator)

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  numerator = self.children['numerator'].return_ascent(state_copy)
  return ascent.reciprocal_().mul_(numerator)

@@ -154,14 +154,14 @@ class Divide(OptimizerModule):
  self._set_child_('denominator', denominator)

  @torch.no_grad
- def step(self, state):
- state_copy = state.copy(clone_ascent = True)
+ def step(self, vars):
+ state_copy = vars.copy(clone_ascent = True)
  numerator = self.children['numerator'].return_ascent(state_copy)
- state.update_attrs_(state_copy)
- denominator = self.children['denominator'].return_ascent(state)
+ vars.update_attrs_(state_copy)
+ denominator = self.children['denominator'].return_ascent(vars)

- state.ascent = numerator.div_(denominator)
- return self._update_params_or_step_with_next(state)
+ vars.ascent = numerator.div_(denominator)
+ return self._update_params_or_step_with_next(vars)


  class Pow(OptimizerModule):
@@ -175,11 +175,11 @@ class Pow(OptimizerModule):
  self.power = power

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.power, (int, float)):
  return ascent.pow_(self.power)

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  power = self.children['power'].return_ascent(state_copy)
  return ascent.pow_(power)

@@ -194,11 +194,11 @@ class RPow(OptimizerModule):
  self.base = base

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.base, (int, float)):
  return self.base ** ascent

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  base = self.children['base'].return_ascent(state_copy)
  return base.pow_(ascent)

@@ -214,14 +214,14 @@ class Power(OptimizerModule):
  self._set_child_('power', power)

  @torch.no_grad
- def step(self, state):
- state_copy = state.copy(clone_ascent = True)
+ def step(self, vars):
+ state_copy = vars.copy(clone_ascent = True)
  base = self.children['base'].return_ascent(state_copy)
- state.update_attrs_(state_copy)
- power = self.children['power'].return_ascent(state)
+ vars.update_attrs_(state_copy)
+ power = self.children['power'].return_ascent(vars)

- state.ascent = base.pow_(power)
- return self._update_params_or_step_with_next(state)
+ vars.ascent = base.pow_(power)
+ return self._update_params_or_step_with_next(vars)


  class Lerp(OptimizerModule):
@@ -235,9 +235,9 @@ class Lerp(OptimizerModule):
  self.weight = weight

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  end = self.children['end'].return_ascent(state_copy)
  return ascent.lerp_(end, self.weight)

@@ -259,15 +259,15 @@ class Interpolate(OptimizerModule):
  self.weight = weight

  @torch.no_grad
- def step(self, state):
- state_copy = state.copy(clone_ascent = True)
+ def step(self, vars):
+ state_copy = vars.copy(clone_ascent = True)
  input = self.children['input'].return_ascent(state_copy)
- state.update_attrs_(state_copy)
- end = self.children['end'].return_ascent(state)
+ vars.update_attrs_(state_copy)
+ end = self.children['end'].return_ascent(vars)

- state.ascent = input.lerp_(end, weight = self.weight)
+ vars.ascent = input.lerp_(end, weight = self.weight)

- return self._update_params_or_step_with_next(state)
+ return self._update_params_or_step_with_next(vars)

  class AddMagnitude(OptimizerModule):
  """Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update.
@@ -288,11 +288,11 @@ class AddMagnitude(OptimizerModule):
  self.add_to_zero = add_to_zero

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if isinstance(self.value, (int, float)):
  if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
  return ascent.add_(ascent.sign_().mul_(self.value))

- state_copy = state.copy(clone_ascent = True)
+ state_copy = vars.copy(clone_ascent = True)
  v = self.children['value'].return_ascent(state_copy)
  return ascent.add_(v.abs_().mul_(ascent.sign()))
@@ -26,26 +26,26 @@ class Sum(OptimizerModule):
  self._set_child_(i, module)

  @torch.no_grad
- def step(self, state):
+ def step(self, vars):
  if len(self.children) == 1:
- state.ascent = self.children[0].return_ascent(state)
- if self.scalar is not None: state.ascent += self.scalar
- return self._update_params_or_step_with_next(state)
+ vars.ascent = self.children[0].return_ascent(vars)
+ if self.scalar is not None: vars.ascent += self.scalar
+ return self._update_params_or_step_with_next(vars)

  sum = None
  for i, c in sorted(self.children.items(), key=lambda x: x[0]):
- if i == len(self.children) - 1: cur_state = state
- else: cur_state = state.copy(clone_ascent = True)
+ if i == len(self.children) - 1: cur_state = vars
+ else: cur_state = vars.copy(clone_ascent = True)

  if sum is None: sum = c.return_ascent(cur_state)
  else: sum += c.return_ascent(cur_state)

- if i != len(self.children) - 1: state.update_attrs_(cur_state)
+ if i != len(self.children) - 1: vars.update_attrs_(cur_state)

  assert sum is not None
  if self.scalar is not None: sum += self.scalar
- state.ascent = sum
- return self._update_params_or_step_with_next(state)
+ vars.ascent = sum
+ return self._update_params_or_step_with_next(vars)

  class Mean(OptimizerModule):
  """calculates mean of multiple updates.
@@ -69,28 +69,28 @@ class Mean(OptimizerModule):
  self._set_child_(i, module)

  @torch.no_grad
- def step(self, state):
+ def step(self, vars):
  if len(self.children) == 1:
- state.ascent = self.children[0].return_ascent(state)
- if self.scalar is not None: state.ascent += self.scalar
- if self.n_values > 1: state.ascent /= self.n_values
- return self._update_params_or_step_with_next(state)
+ vars.ascent = self.children[0].return_ascent(vars)
+ if self.scalar is not None: vars.ascent += self.scalar
+ if self.n_values > 1: vars.ascent /= self.n_values
+ return self._update_params_or_step_with_next(vars)

  sum = None
  for i, c in sorted(self.children.items(), key=lambda x: x[0]):
- if i == len(self.children) - 1: cur_state = state
- else: cur_state = state.copy(clone_ascent = True)
+ if i == len(self.children) - 1: cur_state = vars
+ else: cur_state = vars.copy(clone_ascent = True)

  if sum is None: sum = c.return_ascent(cur_state)
  else: sum += c.return_ascent(cur_state)

- if i != len(self.children) - 1: state.update_attrs_(cur_state)
+ if i != len(self.children) - 1: vars.update_attrs_(cur_state)

  assert sum is not None
  if self.scalar is not None: sum += self.scalar
  if self.n_values > 1: sum /= self.n_values
- state.ascent = sum
- return self._update_params_or_step_with_next(state)
+ vars.ascent = sum
+ return self._update_params_or_step_with_next(vars)

  class Product(OptimizerModule):
  """calculates product of multiple updates.
@@ -112,23 +112,23 @@ class Product(OptimizerModule):
  self._set_child_(i, module)

  @torch.no_grad
- def step(self, state):
+ def step(self, vars):
  if len(self.children) == 1:
- state.ascent = self.children[0].return_ascent(state)
- if self.scalar is not None: state.ascent *= self.scalar
- return self._update_params_or_step_with_next(state)
+ vars.ascent = self.children[0].return_ascent(vars)
+ if self.scalar is not None: vars.ascent *= self.scalar
+ return self._update_params_or_step_with_next(vars)

  prod = None
  for i, c in sorted(self.children.items(), key=lambda x: x[0]):
- if i == len(self.children) - 1: cur_state = state
- else: cur_state = state.copy(clone_ascent = True)
+ if i == len(self.children) - 1: cur_state = vars
+ else: cur_state = vars.copy(clone_ascent = True)

  if prod is None: prod = c.return_ascent(cur_state)
  else: prod *= c.return_ascent(cur_state)

- if i != len(self.children) - 1: state.update_attrs_(cur_state)
+ if i != len(self.children) - 1: vars.update_attrs_(cur_state)

  assert prod is not None
  if self.scalar is not None: prod *= self.scalar
- state.ascent = prod
- return self._update_params_or_step_with_next(state)
+ vars.ascent = prod
+ return self._update_params_or_step_with_next(vars)
@@ -18,7 +18,7 @@ class Operation(OptimizerModule):
  self.operation = methodcaller(f'{operation}_')

  @torch.no_grad
- def _update(self, state, ascent): return self.operation(ascent)
+ def _update(self, vars, ascent): return self.operation(ascent)

  class Reciprocal(OptimizerModule):
  """*1 / update*"""
@@ -26,7 +26,7 @@ class Reciprocal(OptimizerModule):
  super().__init__({})

  @torch.no_grad()
- def _update(self, state, ascent): return ascent.reciprocal_()
+ def _update(self, vars, ascent): return ascent.reciprocal_()

  class Negate(OptimizerModule):
  """minus update"""
@@ -34,7 +34,7 @@ class Negate(OptimizerModule):
  super().__init__({})

  @torch.no_grad()
- def _update(self, state, ascent): return ascent.neg_()
+ def _update(self, vars, ascent): return ascent.neg_()


  def sign_grad_(params: Iterable[torch.Tensor]):
@@ -51,7 +51,7 @@ class Sign(OptimizerModule):
  super().__init__({})

  @torch.no_grad
- def _update(self, state, ascent): return ascent.sign_()
+ def _update(self, vars, ascent): return ascent.sign_()

  class Abs(OptimizerModule):
  """takes absolute values of the update."""
@@ -59,7 +59,7 @@ class Abs(OptimizerModule):
  super().__init__({})

  @torch.no_grad
- def _update(self, state, ascent): return ascent.abs_()
+ def _update(self, vars, ascent): return ascent.abs_()

  class Sin(OptimizerModule):
  """applies sin function to the ascent"""
@@ -67,7 +67,7 @@ class Sin(OptimizerModule):
  super().__init__({})

  @torch.no_grad
- def _update(self, state, ascent): return ascent.sin_()
+ def _update(self, vars, ascent): return ascent.sin_()

  class Cos(OptimizerModule):
  """applies cos function to the ascent"""
@@ -75,7 +75,7 @@ class Cos(OptimizerModule):
  super().__init__({})

  @torch.no_grad
- def _update(self, state, ascent): return ascent.cos_()
+ def _update(self, vars, ascent): return ascent.cos_()


  class NanToNum(OptimizerModule):
@@ -97,7 +97,7 @@ class NanToNum(OptimizerModule):
  self.neginf = neginf

  @torch.no_grad()
- def _update(self, state, ascent): return ascent.nan_to_num_(self.nan, self.posinf, self.neginf)
+ def _update(self, vars, ascent): return ascent.nan_to_num_(self.nan, self.posinf, self.neginf)


  class MagnitudePower(OptimizerModule):
@@ -107,7 +107,7 @@ class MagnitudePower(OptimizerModule):
  self.value = value

  @torch.no_grad()
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if self.value % 2 == 1: return ascent.pow_(self.value)
  return ascent.abs().pow_(self.value) * ascent.sign()

@@ -31,7 +31,7 @@ class Adagrad(OptimizerModule):
  self.cur_step = 0

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  settings = self.get_all_group_keys()
  if self.cur_step == 0: init = ascent.full_like(settings['initial_accumulator_value'])
  else: init = None
@@ -48,7 +48,7 @@ class Adam(OptimizerModule):
  self.amsgrad = amsgrad

  @torch.no_grad
- def step(self, state):
+ def step(self, vars):
  # Adam step is a bit differet from other optimizer steps
  # due to how common it is, I implemented two additional optimizations,

@@ -85,14 +85,14 @@ class Adam(OptimizerModule):
  alpha = settings['alpha']

  # get params if ascent is None so we need params to access their gradient as initial ascent
- if state.ascent is None:
+ if vars.ascent is None:
  if params is None: pg = self.get_params()
  else: pg = params
  else:
  pg = None

  ret = _adam_step(
- ascent=state.maybe_use_grad_(pg),
+ ascent=vars.maybe_use_grad_(pg),
  exp_avg = exp_avg,
  exp_avg_sq = exp_avg_sq,
  alpha = alpha,
@@ -107,12 +107,12 @@ class Adam(OptimizerModule):
  self.cur_step += 1
  if params is None:
  assert ret is not None
- state.ascent = ret
- return self._update_params_or_step_with_next(state)
+ vars.ascent = ret
+ return self._update_params_or_step_with_next(vars)

  # next module is either None or LR
- if self.next_module is None: return state.get_loss()
+ if self.next_module is None: return vars.get_loss()

  # step with LR, which has _skip = True so it won't apply lr, but may step with the scheduler
- self.next_module._update(state, None) # type:ignore
- return state.get_loss()
+ self.next_module._update(vars, None) # type:ignore
+ return vars.get_loss()
@@ -22,7 +22,7 @@ class Lion(OptimizerModule):
  super().__init__(defaults)

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  beta1, beta2 = self.get_group_keys('beta1', 'beta2')
  ema = self.get_state_key('ema')
  return _lion_step_(ascent,ema,beta1,beta2)
@@ -40,7 +40,7 @@ class RMSProp(OptimizerModule):
  self.centered = centered

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  settings = self.get_all_group_keys()
  if self.centered:
  mean, mean_sqr = self.get_state_keys('mean', 'mean_sqr')
@@ -49,7 +49,7 @@ class Rprop(OptimizerModule):
  self.backtrack = backtrack

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  params = self.get_params()

  sign = ascent.sign_()
@@ -15,7 +15,7 @@ class SGD(OptimizerModule):
  weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
  nesterov (bool, optional):
  enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
- alpha (float, optional): learning rate. Defaults to 1e-3.
+ alpha (float, optional): learning rate. Defaults to 1.
  """
  def __init__(
  self,
@@ -32,7 +32,7 @@ class SGD(OptimizerModule):
  self.current_step = 0

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  params = self.get_params()
  settings = self.get_all_group_keys()

@@ -116,7 +116,7 @@ class ZeropowerViaNewtonSchulz(OptimizerModule):
  if compiled: self._zeropower_via_newtonschulz5 = _compiled_zeropower_via_newtonschulz5
  else: self._zeropower_via_newtonschulz5 = _zeropower_via_newtonschulz5

- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  toggle, ns_steps, adaptive = self.get_group_keys('newtonshultz', 'ns_steps', 'adaptive', cls=list)

  for asc, enable, steps, ada in zip(ascent, toggle, ns_steps, adaptive):
@@ -146,11 +146,11 @@ class DualNormCorrection(OptimizerModule):
  defaults = dict(adaptive_scale_min = adaptive_scale_min, adaptive_scale_max = adaptive_scale_max)
  super().__init__(defaults)

- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  params = self.get_params()
  adaptive_scale_min, adaptive_scale_max = self.get_group_keys('adaptive_scale_min', 'adaptive_scale_max')

- for asc, grad, min, max in zip(ascent, state.maybe_compute_grad_(params), adaptive_scale_min, adaptive_scale_max):
+ for asc, grad, min, max in zip(ascent, vars.maybe_compute_grad_(params), adaptive_scale_min, adaptive_scale_max):
  if len([i for i in asc.shape if i > 1]) != 0:
  scale = torch.einsum('ij,ij->', grad.view(grad.shape[0], -1), asc.view(asc.shape[0], -1))
  if min is not None or max is not None: scale = scale.clip(min, max)
@@ -80,7 +80,7 @@ class Orthogonalize(OptimizerModule):
  super().__init__(defaults, target = target)
  self.warn_fail = warn_fail

- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  toggle = self.get_group_key('orth', cls=list)
  _orthogonalize_update_(ascent, toggle, self.warn_fail)
  return ascent
@@ -27,7 +27,7 @@ class Dropout(OptimizerModule):
  super().__init__(defaults)

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  p = self.get_group_key('p')

  ascent *= ascent.bernoulli_like(p)
@@ -18,7 +18,7 @@ def add_noise_(
  grads += grads.sample_like(alpha, distribution)

  elif mode == 'global':
- grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution)
+ grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution) # type:ignore

  elif mode == 'param':
  grads += grads.sample_like(grads.abs().mean()*alpha, distribution)
@@ -57,7 +57,7 @@ class AddNoise(OptimizerModule):
  self.mode: Literal["absolute", "global", "param", "channel"] = mode

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  alpha = self.get_group_key('alpha')

  add_noise_(ascent, alpha, self.distribution, self.mode)
@@ -72,6 +72,6 @@ class Random(OptimizerModule):
  self.distribution: Distributions = distribution

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  alpha = self.get_group_key('alpha')
  return ascent.sample_like(alpha, self.distribution)
@@ -29,7 +29,7 @@ def _normalize_grad_(
  if not isinstance(grads, TensorList): grads = TensorList(grads)
  norm = grads.total_vector_norm(ord)
  if norm > min:
- grads /= norm / norm_value
+ grads /= norm / norm_value # type:ignore

  @torch.no_grad
  def normalize_grad_(
@@ -112,7 +112,7 @@ class Normalize(OptimizerModule):
  self.min_numel = min_numel

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  _normalize_grad_(
  ascent,
  norm_value = self.norm_value,
@@ -225,7 +225,7 @@ class Centralize(OptimizerModule):
  self.min_numel = min_numel

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  _centralize_grad_(
  ascent,
  mode = self.mode,
@@ -258,7 +258,7 @@ class ClipValue(OptimizerModule):
  super().__init__(defaults)

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  value = self.get_group_key('value')
  ascent.clamp_(-value, value)
  return ascent
@@ -317,7 +317,7 @@ class ClipNorm(OptimizerModule):
  self.mode: typing.Literal["global", "param", "channel"] = mode

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  _normalize_grad_(
  ascent,
  norm_value = self.max_norm,
@@ -58,7 +58,7 @@ class OrthoGrad(OptimizerModule):
  self.renormalize = renormalize
  self.sqrt_scale = sqrt_scale

- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  params = self.get_params()

  if self.renormalize: orig_norm = ascent.norm(2) + self.eps
@@ -79,7 +79,7 @@ class WeightDecay(OptimizerModule):
  self.ord = ord

  @torch.no_grad
- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  params = self.get_params()
  alpha = self.get_group_key('alpha')

@@ -81,7 +81,7 @@ if TYPE_CHECKING:

  # self.id = random.random()

- # def step(self, state):
+ # def step(self, vars):
  # if self.cur % self.update_every == 0:
  # self.scheduler_step_fn()
  # self.cur_lr = self.dummy_opt.first_param_group['lr']
@@ -113,7 +113,7 @@ class LRWarmup(OptimizerModule):

  self.cur = 0

- def _update(self, state, ascent):
+ def _update(self, vars, ascent):
  if self.cur < self.delay_steps:
  if self.start_lr != 1: ascent *= self.start_lr