torchzero 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. torchzero/core/__init__.py +1 -1
  2. torchzero/core/module.py +72 -49
  3. torchzero/core/tensorlist_optimizer.py +1 -1
  4. torchzero/modules/adaptive/adaptive.py +11 -11
  5. torchzero/modules/experimental/experimental.py +41 -41
  6. torchzero/modules/experimental/quad_interp.py +8 -8
  7. torchzero/modules/experimental/subspace.py +37 -37
  8. torchzero/modules/gradient_approximation/base_approximator.py +19 -24
  9. torchzero/modules/gradient_approximation/fdm.py +1 -1
  10. torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
  11. torchzero/modules/gradient_approximation/rfdm.py +1 -1
  12. torchzero/modules/line_search/armijo.py +8 -8
  13. torchzero/modules/line_search/base_ls.py +8 -8
  14. torchzero/modules/line_search/directional_newton.py +14 -14
  15. torchzero/modules/line_search/grid_ls.py +7 -7
  16. torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
  17. torchzero/modules/meta/alternate.py +4 -4
  18. torchzero/modules/meta/grafting.py +23 -23
  19. torchzero/modules/meta/optimizer_wrapper.py +14 -14
  20. torchzero/modules/meta/return_overrides.py +8 -8
  21. torchzero/modules/misc/accumulate.py +6 -6
  22. torchzero/modules/misc/basic.py +16 -16
  23. torchzero/modules/misc/lr.py +2 -2
  24. torchzero/modules/misc/multistep.py +7 -7
  25. torchzero/modules/misc/on_increase.py +9 -9
  26. torchzero/modules/momentum/momentum.py +4 -4
  27. torchzero/modules/operations/multi.py +44 -44
  28. torchzero/modules/operations/reduction.py +28 -28
  29. torchzero/modules/operations/singular.py +9 -9
  30. torchzero/modules/optimizers/adagrad.py +1 -1
  31. torchzero/modules/optimizers/adam.py +8 -8
  32. torchzero/modules/optimizers/lion.py +1 -1
  33. torchzero/modules/optimizers/rmsprop.py +1 -1
  34. torchzero/modules/optimizers/rprop.py +1 -1
  35. torchzero/modules/optimizers/sgd.py +2 -2
  36. torchzero/modules/orthogonalization/newtonschulz.py +3 -3
  37. torchzero/modules/orthogonalization/svd.py +1 -1
  38. torchzero/modules/regularization/dropout.py +1 -1
  39. torchzero/modules/regularization/noise.py +3 -3
  40. torchzero/modules/regularization/normalization.py +5 -5
  41. torchzero/modules/regularization/ortho_grad.py +1 -1
  42. torchzero/modules/regularization/weight_decay.py +1 -1
  43. torchzero/modules/scheduling/lr_schedulers.py +2 -2
  44. torchzero/modules/scheduling/step_size.py +8 -8
  45. torchzero/modules/second_order/newton.py +12 -12
  46. torchzero/modules/smoothing/__init__.py +1 -1
  47. torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
  48. torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
  49. torchzero/modules/weight_averaging/ema.py +3 -3
  50. torchzero/modules/weight_averaging/swa.py +8 -8
  51. torchzero/optim/first_order/forward_gradient.py +1 -1
  52. torchzero/optim/modular.py +4 -4
  53. torchzero/tensorlist.py +8 -1
  54. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/METADATA +1 -1
  55. torchzero-0.1.5.dist-info/RECORD +104 -0
  56. torchzero-0.1.3.dist-info/RECORD +0 -104
  57. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/LICENSE +0 -0
  58. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/WHEEL +0 -0
  59. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/top_level.txt +0 -0
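Most of the changes below are a mechanical rename: the argument of `OptimizerModule.step` and `OptimizerModule._update` is now called `vars` instead of `state` throughout the listed module classes. As an illustration only, here is a minimal sketch of what a downstream custom module might look like against the 0.1.5 signature. It is modeled on the `Alpha` and `Fill` hunks further down; the class name `ScaleBy` and the import path are assumptions, not part of this diff.

# Hypothetical example, not part of the diff. Modeled on the Alpha and Fill hunks
# below; the import path is assumed from the file list (torchzero/core/module.py).
import torch
from torchzero.core.module import OptimizerModule

class ScaleBy(OptimizerModule):
    """Multiplies the ascent direction by a constant factor (illustration only)."""
    def __init__(self, factor: float = 0.5):
        super().__init__({"factor": factor})    # per-group setting, as in Fill

    @torch.no_grad
    def _update(self, vars, ascent):            # 0.1.5 signature: `vars`, not `state`
        factor = self.get_group_key('factor')   # same accessor Alpha uses for 'alpha'
        ascent *= factor                        # in-place, mirroring Alpha
        return ascent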
torchzero/modules/meta/grafting.py
@@ -45,15 +45,15 @@ class Graft(OptimizerModule):
 

     @torch.no_grad
-    def step(self, state):
-        state_copy = state.copy(clone_ascent=True)
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent=True)
         magnitude = self.children['magnitude'].return_ascent(state_copy)

-        if state_copy.grad is not None: state.grad = state_copy.grad
-        if state_copy.fx0 is not None: state.fx0 = state_copy.fx0
-        if state_copy.fx0_approx is not None: state.fx0_approx = state_copy.fx0_approx
+        if state_copy.grad is not None: vars.grad = state_copy.grad
+        if state_copy.fx0 is not None: vars.fx0 = state_copy.fx0
+        if state_copy.fx0_approx is not None: vars.fx0_approx = state_copy.fx0_approx

-        direction = self.children['direction'].return_ascent(state)
+        direction = self.children['direction'].return_ascent(vars)

         if self.layerwise:
             M = magnitude.norm(self.ord)
@@ -65,8 +65,8 @@ class Graft(OptimizerModule):
             D = direction.total_vector_norm(self.ord)
             if D == 0: D = M

-        state.ascent = direction.mul_(M / (D + self.eps))
-        return self._update_params_or_step_with_next(state)
+        vars.ascent = direction.mul_(M / (D + self.eps))
+        return self._update_params_or_step_with_next(vars)



@@ -94,17 +94,17 @@ class SignGrafting(OptimizerModule):
 

     @torch.no_grad
-    def step(self, state):
-        state_copy = state.copy(clone_ascent=True)
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent=True)
         magnitude = self.children['magnitude'].return_ascent(state_copy)

         # make sure to store grad and fx0 if it was calculated
-        state.update_attrs_(state_copy)
+        vars.update_attrs_(state_copy)

-        sign = self.children['sign'].return_ascent(state)
+        sign = self.children['sign'].return_ascent(vars)

-        state.ascent = magnitude.copysign_(sign)
-        return self._update_params_or_step_with_next(state)
+        vars.ascent = magnitude.copysign_(sign)
+        return self._update_params_or_step_with_next(vars)


 class IntermoduleCautious(OptimizerModule):
@@ -153,17 +153,17 @@ class IntermoduleCautious(OptimizerModule):
         self.mode: Literal["zero", "grad", "backtrack", "compare_module"] = mode

     @torch.no_grad
-    def step(self, state):
+    def step(self, vars):
         params = None
-        state_copy = state.copy(clone_ascent=True)
+        state_copy = vars.copy(clone_ascent=True)
         ascent = self.children['main'].return_ascent(state_copy)
-        state.update_attrs_(state_copy)
+        vars.update_attrs_(state_copy)

-        if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(state)
+        if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(vars)
         else:
             params = self.get_params()
-            if self.compare_mode == 'ascent': compare: TensorList = state.maybe_use_grad_(params)
-            elif self.compare_mode == 'grad': compare: TensorList = state.maybe_compute_grad_(params)
+            if self.compare_mode == 'ascent': compare: TensorList = vars.maybe_use_grad_(params)
+            elif self.compare_mode == 'grad': compare: TensorList = vars.maybe_compute_grad_(params)
             else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')

         # mask will be > 0 for parameters where both signs are the same
@@ -185,11 +185,11 @@ class IntermoduleCautious(OptimizerModule):

         if self.mode == 'grad':
             params = self.get_params()
-            ascent += state.maybe_compute_grad_(params) * mask.logical_not_()
+            ascent += vars.maybe_compute_grad_(params) * mask.logical_not_()

         elif self.mode == 'compare_module':
             ascent += compare * mask.logical_not_()

-        state.ascent = ascent
-        return self._update_params_or_step_with_next(state, params)
+        vars.ascent = ascent
+        return self._update_params_or_step_with_next(vars, params)
 
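The same rename applies to modules that drive child modules, as in the `Graft` and `SignGrafting` hunks above: a copy of `vars` is passed to one child via `copy(clone_ascent=True)`, any gradient or loss that child computed is carried back with `update_attrs_`, and the combined direction is written to `vars.ascent`. Below is a hedged sketch of that idiom for a hypothetical two-child module; the class name `AverageAscent` and the element-wise averaging are illustrative, not from this diff.

# Hypothetical example following the Graft/SignGrafting pattern above.
# copy, update_attrs_, return_ascent, _set_child_, children and
# _update_params_or_step_with_next all appear in this diff; everything else is assumed.
import torch
from torchzero.core.module import OptimizerModule  # import path assumed

class AverageAscent(OptimizerModule):
    """Averages the ascent directions of two child modules (illustration only)."""
    def __init__(self, first, second):
        super().__init__({})
        self._set_child_('first', first)
        self._set_child_('second', second)

    @torch.no_grad
    def step(self, vars):
        # evaluate the first child on a copy so grad/fx0 can be carried back to `vars`
        state_copy = vars.copy(clone_ascent=True)
        a = self.children['first'].return_ascent(state_copy)
        vars.update_attrs_(state_copy)      # keep grad/fx0 if the child computed them

        b = self.children['second'].return_ascent(vars)

        vars.ascent = a.add_(b).mul_(0.5)   # assumes TensorList supports add_/mul_, as shown in this diff
        return self._update_params_or_step_with_next(vars)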
torchzero/modules/meta/optimizer_wrapper.py
@@ -64,7 +64,7 @@ class Wrap(OptimizerModule):
         self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)

     @torch.no_grad
-    def step(self, state):
+    def step(self, vars):
         # check attrs
         # if self.pass_closure:
         #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
@@ -75,22 +75,22 @@

         if self.next_module is None:
             # set grad to ascent and make a step with the optimizer
-            g = state.maybe_use_grad_(params)
+            g = vars.maybe_use_grad_(params)
             params.set_grad_(g)
-            state.fx0 = self.optimizer.step()
-            return state.get_loss()
+            vars.fx0 = self.optimizer.step()
+            return vars.get_loss()


         params_before_step = params.clone()

-        g = state.maybe_use_grad_(params)
+        g = vars.maybe_use_grad_(params)
         params.set_grad_(g)
-        state.fx0 = self.optimizer.step()
+        vars.fx0 = self.optimizer.step()

         # calculate update as difference in params
-        state.ascent = params_before_step - params
+        vars.ascent = params_before_step - params
         params.set_(params_before_step)
-        return self.next_module.step(state)
+        return self.next_module.step(vars)


 class WrapClosure(OptimizerModule):
@@ -148,7 +148,7 @@ class WrapClosure(OptimizerModule):
         self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)

     @torch.no_grad
-    def step(self, state):
+    def step(self, vars):
         # check attrs
         # if self.pass_closure:
         #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
@@ -159,15 +159,15 @@

         if self.next_module is None:
             # set grad to ascent and make a step with the optimizer
-            state.fx0 = self.optimizer.step(state.closure) # type:ignore
-            return state.get_loss()
+            vars.fx0 = self.optimizer.step(vars.closure) # type:ignore
+            return vars.get_loss()


         params_before_step = params.clone()
-        state.fx0 = self.optimizer.step(state.closure) # type:ignore
+        vars.fx0 = self.optimizer.step(vars.closure) # type:ignore

         # calculate update as difference in params
-        state.ascent = params_before_step - params
+        vars.ascent = params_before_step - params
         params.set_(params_before_step)
-        return self.next_module.step(state)
+        return self.next_module.step(vars)
 
torchzero/modules/meta/return_overrides.py
@@ -9,12 +9,12 @@ class SetGrad(OptimizerModule):
         super().__init__({})

     @torch.no_grad
-    def step(self, state):
+    def step(self, vars):
         if self.next_module is not None: raise ValueError("SetGrad can't have children")
         params = self.get_params()
-        g = state.maybe_use_grad_(params) # this may execute the closure which might be modified
+        g = vars.maybe_use_grad_(params) # this may execute the closure which might be modified
         params.set_grad_(g)
-        return state.get_loss()
+        return vars.get_loss()


 class ReturnAscent(OptimizerModule):
@@ -23,10 +23,10 @@
         super().__init__({})

     @torch.no_grad
-    def step(self, state) -> TensorList: # type:ignore
+    def step(self, vars) -> TensorList: # type:ignore
         if self.next_module is not None: raise ValueError("ReturnAscent can't have children")
         params = self.get_params()
-        update = state.maybe_use_grad_(params) # this will execute the closure which might be modified
+        update = vars.maybe_use_grad_(params) # this will execute the closure which might be modified
         return update

 class ReturnClosure(OptimizerModule):
@@ -38,9 +38,9 @@
         super().__init__({})

     @torch.no_grad
-    def step(self, state) -> _ClosureType: # type:ignore
+    def step(self, vars) -> _ClosureType: # type:ignore
         if self.next_module is not None: raise ValueError("ReturnClosure can't have children")
-        if state.closure is None:
+        if vars.closure is None:
             raise ValueError("MakeClosure requires closure")
-        return state.closure
+        return vars.closure
 
torchzero/modules/misc/accumulate.py
@@ -26,18 +26,18 @@ class Accumulate(OptimizerModule):
         self.cur_step = 0

     @torch.no_grad
-    def step(self, state):
+    def step(self, vars):
         self.cur_step += 1

         params = self.get_params()
         accumulated_update = self.get_state_key('accumulated_grads')
-        accumulated_update += state.maybe_use_grad_(params)
+        accumulated_update += vars.maybe_use_grad_(params)

         if self.cur_step % self.n_steps == 0:
-            state.ascent = accumulated_update.clone()
-            if self.mean: state.ascent /= self.n_steps
+            vars.ascent = accumulated_update.clone()
+            if self.mean: vars.ascent /= self.n_steps
             accumulated_update.zero_()
-            return self._update_params_or_step_with_next(state)
+            return self._update_params_or_step_with_next(vars)


-        return state.get_loss()
+        return vars.get_loss()
torchzero/modules/misc/basic.py
@@ -14,7 +14,7 @@ class Alpha(OptimizerModule):
         super().__init__(defaults)

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         # multiply ascent direction by lr in-place
         lr = self.get_group_key('alpha')
         ascent *= lr
@@ -27,7 +27,7 @@ class Clone(OptimizerModule):
         super().__init__({})

     @torch.no_grad
-    def _update(self, state, ascent): return ascent.clone()
+    def _update(self, vars, ascent): return ascent.clone()

 class Identity(OptimizerModule):
     """Does nothing."""
@@ -35,7 +35,7 @@ class Identity(OptimizerModule):
         super().__init__({})

     @torch.no_grad
-    def _update(self, state, ascent): return ascent
+    def _update(self, vars, ascent): return ascent

 class Lambda(OptimizerModule):
     """Applies a function to the ascent direction.
@@ -49,7 +49,7 @@
         self.f = f

     @torch.no_grad()
-    def _update(self, state, ascent): return self.f(ascent)
+    def _update(self, vars, ascent): return self.f(ascent)

 class Grad(OptimizerModule):
     """Uses gradient as the update. This is useful for chains."""
@@ -57,8 +57,8 @@
         super().__init__({})

     @torch.no_grad
-    def _update(self, state, ascent):
-        ascent = state.ascent = state.maybe_compute_grad_(self.get_params())
+    def _update(self, vars, ascent):
+        ascent = vars.ascent = vars.maybe_compute_grad_(self.get_params())
         return ascent

 class Zeros(OptimizerModule):
@@ -66,7 +66,7 @@
         super().__init__({})

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         return ascent.zeros_like()

 class Fill(OptimizerModule):
@@ -74,7 +74,7 @@
         super().__init__({"value": value})

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         return ascent.fill(self.get_group_key('value'))


@@ -83,8 +83,8 @@ class GradToUpdate(OptimizerModule):
     def __init__(self):
         super().__init__({})

-    def _update(self, state, ascent):
-        state.set_grad_(ascent, self.get_params())
+    def _update(self, vars, ascent):
+        vars.set_grad_(ascent, self.get_params())
         return ascent

 class MakeClosure(OptimizerModule):
@@ -93,12 +93,12 @@
         super().__init__({})
         self._set_child_('modules', modules)

-    def step(self, state):
-        if state.closure is None: raise ValueError("MakeClosure requires a closure")
+    def step(self, vars):
+        if vars.closure is None: raise ValueError("MakeClosure requires a closure")

         params = self.get_params()
-        orig_closure = state.closure
-        orig_state = state.copy(True)
+        orig_closure = vars.closure
+        orig_state = vars.copy(True)

         def new_closure(backward = True):
             if backward:
@@ -110,6 +110,6 @@ class MakeClosure(OptimizerModule):
             else:
                 return orig_closure(False)

-        state.closure = new_closure # type:ignore
-        return self._update_params_or_step_with_next(state)
+        vars.closure = new_closure # type:ignore
+        return self._update_params_or_step_with_next(vars)
 
torchzero/modules/misc/lr.py
@@ -71,7 +71,7 @@ class LR(OptimizerModule):
         self._skip = False

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         # step with scheduler
         if self._scheduler_step_fn is not None:
             if self.cur != 0 and self.cur % self.sheduler_step_every == 0:
@@ -79,7 +79,7 @@

         # add a hook to cycle momentum
         if self.cycle_momentum:
-            state.add_post_step_hook(_set_momentum_hook)
+            vars.add_post_step_hook(_set_momentum_hook)

         # remove init hook to delete reference to scheduler
         if self.cur == 0 and len(self.post_init_hooks) == 1:
torchzero/modules/misc/multistep.py
@@ -20,32 +20,32 @@ class Multistep(OptimizerModule):

         self._set_child_('modules', modules)

-    def step(self, state):
+    def step(self, vars):
         # no next module, just perform multiple steps
         if self.next_module is None:
             ret = None
             for step in range(self.num_steps):
-                state_copy = state.copy(clone_ascent=True) if step != self.num_steps - 1 else state
+                state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
                 ret = self.children['modules'].step(state_copy)

                 # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-                state.grad = None; state.fx0 = None
+                vars.grad = None; vars.fx0 = None

             return ret

         # accumulate steps and pass to next module
         p0 = self.get_params().clone()
         for step in range(self.num_steps):
-            state_copy = state.copy(clone_ascent=True) if step != self.num_steps - 1 else state
+            state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
             self.children['modules'].step(state_copy)

             # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-            state.grad = None; state.fx0 = None
+            vars.grad = None; vars.fx0 = None

         p1 = self.get_params()
-        state.ascent = p0 - p1
+        vars.ascent = p0 - p1

         # undo ascent
         p1.set_(p0)

-        return self._update_params_or_step_with_next(state, p1)
+        return self._update_params_or_step_with_next(vars, p1)
torchzero/modules/misc/on_increase.py
@@ -16,38 +16,38 @@ class NegateOnLossIncrease(OptimizerModule):
         self.backtrack = backtrack

     @torch.no_grad()
-    def step(self, state):
-        if state.closure is None: raise ValueError('NegateOnLossIncrease requires closure.')
-        if state.fx0 is None: state.fx0 = state.closure(False)
+    def step(self, vars):
+        if vars.closure is None: raise ValueError('NegateOnLossIncrease requires closure.')
+        if vars.fx0 is None: vars.fx0 = vars.closure(False)

         # subtract ascent direction to params and see if loss decreases
         params = self.get_params()
-        ascent_direction = state.maybe_use_grad_(params)
+        ascent_direction = vars.maybe_use_grad_(params)
         params -= ascent_direction
-        state.fx0_approx = state.closure(False)
+        vars.fx0_approx = vars.closure(False)

         # if this has no children, update params and return loss
         if self.next_module is None:
             if params is None: params = self.get_params()

-            if state.fx0_approx > state.fx0:
+            if vars.fx0_approx > vars.fx0:
                 # loss increased, so we negate thea scent direction
                 # we are currently at params - ascent direction
                 # so we add twice the ascent direction
                 params.add_(ascent_direction, alpha = 2 if self.backtrack else 1)

             # else: we are already at a lower loss point
-            return state.get_loss()
+            return vars.get_loss()

         # otherwise undo the ascent direction because it is passed to the child
         params += ascent_direction

         # if loss increases, negate ascent direction
-        if state.fx0_approx > state.fx0:
+        if vars.fx0_approx > vars.fx0:
             if self.backtrack: ascent_direction.neg_()
             else: ascent_direction.zero_()

         # otherwise undo the ascent direction and pass the updated ascent direction to the child
-        return self.next_module.step(state)
+        return self.next_module.step(vars)

 
torchzero/modules/momentum/momentum.py
@@ -21,7 +21,7 @@ class HeavyBall(OptimizerModule):
         super().__init__(defaults)

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity', init = ascent)
         settings = self.get_all_group_keys()
         updated_direction = _heavyball_step(ascent, velocity, settings['momentum'], settings['dampening'])
@@ -52,7 +52,7 @@ class NesterovMomentum(OptimizerModule):
         super().__init__(defaults)

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity')
         settings = self.get_all_group_keys()
         _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
@@ -65,7 +65,7 @@ class GradientAveraging(OptimizerModule):
         super().__init__(defaults)

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity')
         dampening = self.get_group_key('dampening')

@@ -89,7 +89,7 @@ class RandomCoordinateMomentum(OptimizerModule):
         self.nesterov = nesterov

     @torch.no_grad
-    def _update(self, state, ascent):
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity', init = ascent)
         settings = self.get_all_group_keys()