torchzero 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. torchzero/core/__init__.py +1 -1
  2. torchzero/core/module.py +72 -49
  3. torchzero/core/tensorlist_optimizer.py +1 -1
  4. torchzero/modules/adaptive/adaptive.py +11 -11
  5. torchzero/modules/experimental/experimental.py +41 -41
  6. torchzero/modules/experimental/quad_interp.py +8 -8
  7. torchzero/modules/experimental/subspace.py +37 -37
  8. torchzero/modules/gradient_approximation/base_approximator.py +19 -24
  9. torchzero/modules/gradient_approximation/fdm.py +1 -1
  10. torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
  11. torchzero/modules/gradient_approximation/rfdm.py +1 -1
  12. torchzero/modules/line_search/armijo.py +8 -8
  13. torchzero/modules/line_search/base_ls.py +8 -8
  14. torchzero/modules/line_search/directional_newton.py +14 -14
  15. torchzero/modules/line_search/grid_ls.py +7 -7
  16. torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
  17. torchzero/modules/meta/alternate.py +4 -4
  18. torchzero/modules/meta/grafting.py +23 -23
  19. torchzero/modules/meta/optimizer_wrapper.py +14 -14
  20. torchzero/modules/meta/return_overrides.py +8 -8
  21. torchzero/modules/misc/accumulate.py +6 -6
  22. torchzero/modules/misc/basic.py +16 -16
  23. torchzero/modules/misc/lr.py +2 -2
  24. torchzero/modules/misc/multistep.py +7 -7
  25. torchzero/modules/misc/on_increase.py +9 -9
  26. torchzero/modules/momentum/momentum.py +4 -4
  27. torchzero/modules/operations/multi.py +44 -44
  28. torchzero/modules/operations/reduction.py +28 -28
  29. torchzero/modules/operations/singular.py +9 -9
  30. torchzero/modules/optimizers/adagrad.py +1 -1
  31. torchzero/modules/optimizers/adam.py +8 -8
  32. torchzero/modules/optimizers/lion.py +1 -1
  33. torchzero/modules/optimizers/rmsprop.py +1 -1
  34. torchzero/modules/optimizers/rprop.py +1 -1
  35. torchzero/modules/optimizers/sgd.py +2 -2
  36. torchzero/modules/orthogonalization/newtonschulz.py +3 -3
  37. torchzero/modules/orthogonalization/svd.py +1 -1
  38. torchzero/modules/regularization/dropout.py +1 -1
  39. torchzero/modules/regularization/noise.py +3 -3
  40. torchzero/modules/regularization/normalization.py +5 -5
  41. torchzero/modules/regularization/ortho_grad.py +1 -1
  42. torchzero/modules/regularization/weight_decay.py +1 -1
  43. torchzero/modules/scheduling/lr_schedulers.py +2 -2
  44. torchzero/modules/scheduling/step_size.py +8 -8
  45. torchzero/modules/second_order/newton.py +12 -12
  46. torchzero/modules/smoothing/__init__.py +1 -1
  47. torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
  48. torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
  49. torchzero/modules/weight_averaging/ema.py +3 -3
  50. torchzero/modules/weight_averaging/swa.py +8 -8
  51. torchzero/optim/first_order/forward_gradient.py +1 -1
  52. torchzero/optim/modular.py +4 -4
  53. torchzero/tensorlist.py +8 -1
  54. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/METADATA +1 -1
  55. torchzero-0.1.5.dist-info/RECORD +104 -0
  56. torchzero-0.1.3.dist-info/RECORD +0 -104
  57. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/LICENSE +0 -0
  58. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/WHEEL +0 -0
  59. {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/top_level.txt +0 -0
@@ -5,14 +5,14 @@ from collections import abc
5
5
  import torch
6
6
 
7
7
  from ... import tensorlist as tl
8
- from ...core import OptimizationState, OptimizerModule, _Chain, _maybe_pass_backward
8
+ from ...core import OptimizationVars, OptimizerModule, _Chain, _maybe_pass_backward
9
9
  # this whole thing can also be implemented via parameter vectors.
10
10
  # Need to test which one is more efficient...
11
11
 
12
12
  class Projection(ABC):
13
13
  n = 1
14
14
  @abstractmethod
15
- def sample(self, params: tl.TensorList, state: OptimizationState) -> list[tl.TensorList]:
15
+ def sample(self, params: tl.TensorList, vars: OptimizationVars) -> list[tl.TensorList]:
16
16
  """Generate a projection.
17
17
 
18
18
  Args:
@@ -28,7 +28,7 @@ class ProjRandom(Projection):
28
28
  self.distribution: tl.Distributions = distribution
29
29
  self.n = n
30
30
 
31
- def sample(self, params: tl.TensorList, state: OptimizationState):
31
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
32
32
  return [params.sample_like(distribution=self.distribution) for _ in range(self.n)]
33
33
 
34
34
 
@@ -42,7 +42,7 @@ class Proj2Masks(Projection):
42
42
  def n(self):
43
43
  return self.n_pairs * 2
44
44
 
45
- def sample(self, params: tl.TensorList, state: OptimizationState):
45
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
46
46
  projections = []
47
47
  for i in range(self.n_pairs):
48
48
  mask = params.bernoulli_like(0.5)
@@ -55,9 +55,9 @@ class Proj2Masks(Projection):
55
55
 
56
56
  class ProjAscent(Projection):
57
57
  """Use ascent direction as the projection."""
58
- def sample(self, params: tl.TensorList, state: OptimizationState):
59
- if state.ascent is None: raise ValueError
60
- return [state.ascent]
58
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
59
+ if vars.ascent is None: raise ValueError
60
+ return [vars.ascent]
61
61
 
62
62
  class ProjAscentRay(Projection):
63
63
  def __init__(self, eps = 0.1, n = 1, distribution: tl.Distributions = 'normal', ):
@@ -65,14 +65,14 @@ class ProjAscentRay(Projection):
65
65
  self.distribution: tl.Distributions = distribution
66
66
  self.n = n
67
67
 
68
- def sample(self, params: tl.TensorList, state: OptimizationState):
69
- if state.ascent is None: raise ValueError
68
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
69
+ if vars.ascent is None: raise ValueError
70
70
  mean = params.total_mean().detach().cpu().item()
71
- return [state.ascent + state.ascent.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]
71
+ return [vars.ascent + vars.ascent.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]
72
72
 
73
73
  class ProjGrad(Projection):
74
- def sample(self, params: tl.TensorList, state: OptimizationState):
75
- grad = state.maybe_compute_grad_(params)
74
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
75
+ grad = vars.maybe_compute_grad_(params)
76
76
  return [grad]
77
77
 
78
78
  class ProjGradRay(Projection):
@@ -81,8 +81,8 @@ class ProjGradRay(Projection):
81
81
  self.distribution: tl.Distributions = distribution
82
82
  self.n = n
83
83
 
84
- def sample(self, params: tl.TensorList, state: OptimizationState):
85
- grad = state.maybe_compute_grad_(params)
84
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
85
+ grad = vars.maybe_compute_grad_(params)
86
86
  mean = params.total_mean().detach().cpu().item()
87
87
  return [grad + grad.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]
88
88
 
@@ -95,23 +95,23 @@ class ProjGradAscentDifference(Projection):
95
95
  """
96
96
  self.normalize = normalize
97
97
 
98
- def sample(self, params: tl.TensorList, state: OptimizationState):
99
- grad = state.maybe_compute_grad_(params)
98
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
99
+ grad = vars.maybe_compute_grad_(params)
100
100
  if self.normalize:
101
- return [state.ascent / state.ascent.total_vector_norm(2) - grad / grad.total_vector_norm(2)] # type:ignore
101
+ return [vars.ascent / vars.ascent.total_vector_norm(2) - grad / grad.total_vector_norm(2)] # type:ignore
102
102
 
103
- return [state.ascent - grad] # type:ignore
103
+ return [vars.ascent - grad] # type:ignore
104
104
 
105
105
  class ProjLastGradDifference(Projection):
106
106
  def __init__(self):
107
107
  """Use difference between last two gradients as the projection."""
108
108
  self.last_grad = None
109
- def sample(self, params: tl.TensorList, state: OptimizationState):
109
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
110
110
  if self.last_grad is None:
111
- self.last_grad = state.maybe_compute_grad_(params)
111
+ self.last_grad = vars.maybe_compute_grad_(params)
112
112
  return [self.last_grad]
113
113
 
114
- grad = state.maybe_compute_grad_(params)
114
+ grad = vars.maybe_compute_grad_(params)
115
115
  diff = grad - self.last_grad
116
116
  self.last_grad = grad
117
117
  return [diff]
@@ -121,13 +121,13 @@ class ProjLastAscentDifference(Projection):
121
121
  """Use difference between last two ascent directions as the projection."""
122
122
  self.last_direction = T.cast(tl.TensorList, None)
123
123
 
124
- def sample(self, params: tl.TensorList, state: OptimizationState):
124
+ def sample(self, params: tl.TensorList, vars: OptimizationVars):
125
125
  if self.last_direction is None:
126
- self.last_direction: tl.TensorList = state.ascent # type:ignore
126
+ self.last_direction: tl.TensorList = vars.ascent # type:ignore
127
127
  return [self.last_direction]
128
128
 
129
- diff = state.ascent - self.last_direction # type:ignore
130
- self.last_direction = state.ascent # type:ignore
129
+ diff = vars.ascent - self.last_direction # type:ignore
130
+ self.last_direction = vars.ascent # type:ignore
131
131
  return [diff]
132
132
 
133
133
  class ProjNormalize(Projection):
@@ -139,10 +139,10 @@ class ProjNormalize(Projection):
139
139
  def n(self):
140
140
  return sum(proj.n for proj in self.projections)
141
141
 
142
- def sample(self, params: tl.TensorList, state: OptimizationState): # type:ignore
143
- vecs = [proj for obj in self.projections for proj in obj.sample(params, state)]
142
+ def sample(self, params: tl.TensorList, vars: OptimizationVars): # type:ignore
143
+ vecs = [proj for obj in self.projections for proj in obj.sample(params, vars)]
144
144
  norms = [v.total_vector_norm(2) for v in vecs]
145
- return [v/norm if norm!=0 else v.randn_like() for v,norm in zip(vecs,norms)]
145
+ return [v/norm if norm!=0 else v.randn_like() for v,norm in zip(vecs,norms)] # type:ignore
146
146
 
147
147
  class Subspace(OptimizerModule):
148
148
  """This is pretty inefficient, I thought of a much better way to do this via jvp and I will rewrite this soon.
@@ -198,17 +198,17 @@ class Subspace(OptimizerModule):
198
198
  child.add_param_group({"params": params})
199
199
 
200
200
  @torch.no_grad
201
- def step(self, state):
201
+ def step(self, vars):
202
202
  #if self.next_module is None: raise ValueError('RandomProjection needs a child')
203
- if state.closure is None: raise ValueError('RandomProjection needs a closure')
204
- closure = state.closure
203
+ if vars.closure is None: raise ValueError('RandomProjection needs a closure')
204
+ closure = vars.closure
205
205
  params = self.get_params()
206
206
 
207
207
  # every `update_every` steps we generate new random projections.
208
208
  if self.current_step == 0 or (self.update_every is not None and self.current_step % self.update_every == 0):
209
209
 
210
210
  # generate n projection vectors
211
- self.projection_vectors = [sample for proj in self.projections for sample in proj.sample(params, state)]
211
+ self.projection_vectors = [sample for proj in self.projections for sample in proj.sample(params, vars)]
212
212
 
213
213
  # child params is n scalars corresponding to each projection vector
214
214
  self.projected_params = self.children['subspace']._params[0] # type:ignore
@@ -235,7 +235,7 @@ class Subspace(OptimizerModule):
235
235
  # ascent_direction = tl.sum([ascent_direction*v for v in self.projection_vectors])
236
236
 
237
237
  # perform a step with the child
238
- subspace_state = state.copy(False)
238
+ subspace_state = vars.copy(False)
239
239
  subspace_state.closure = projected_closure
240
240
  subspace_state.ascent = None
241
241
  if subspace_state.grad is not None:
@@ -244,11 +244,11 @@ class Subspace(OptimizerModule):
244
244
 
245
245
  # that is going to update child's parameters, which we now project back to the full parameter space
246
246
  residual = tl.sum([vec * p for vec, p in zip(self.projection_vectors, self.projected_params)])
247
- state.ascent = residual.neg_()
247
+ vars.ascent = residual.neg_()
248
248
 
249
249
  # move fx0 and fx0 approx to state
250
- if subspace_state.fx0 is not None: state.fx0 = subspace_state.fx0
251
- if subspace_state.fx0_approx is not None: state.fx0 = subspace_state.fx0_approx
250
+ if subspace_state.fx0 is not None: vars.fx0 = subspace_state.fx0
251
+ if subspace_state.fx0_approx is not None: vars.fx0 = subspace_state.fx0_approx
252
252
  # projected_params are residuals that have been applied to actual params on previous step in some way
253
253
  # therefore they need to now become zero (otherwise they work like momentum with no decay).
254
254
  # note: THIS WON'T WORK WITH INTEGRATIONS, UNLESS THEY PERFORM FULL MINIMIZATION EACH STEP
@@ -256,4 +256,4 @@ class Subspace(OptimizerModule):
256
256
  self.projected_params.zero_()
257
257
 
258
258
  self.current_step += 1
259
- return self._update_params_or_step_with_next(state)
259
+ return self._update_params_or_step_with_next(vars)
@@ -5,7 +5,7 @@ from typing import Any, Literal
5
5
  import torch
6
6
 
7
7
  from ...core import (
8
- OptimizationState,
8
+ OptimizationVars,
9
9
  OptimizerModule,
10
10
  _ClosureType,
11
11
  _maybe_pass_backward,
@@ -39,12 +39,12 @@ class GradientApproximatorBase(OptimizerModule, ABC):
39
39
  super().__init__(defaults, target)
40
40
  self.requires_fx0 = requires_fx0
41
41
 
42
- def _step_make_closure_(self, state: OptimizationState, params: TensorList):
43
- if state.closure is None: raise ValueError("gradient approximation requires closure")
44
- closure = state.closure
42
+ def _step_make_closure_(self, vars: OptimizationVars, params: TensorList):
43
+ if vars.closure is None: raise ValueError("gradient approximation requires closure")
44
+ closure = vars.closure
45
45
 
46
- if self.requires_fx0: fx0 = state.evaluate_fx0_(False)
47
- else: fx0 = state.fx0
46
+ if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
47
+ else: fx0 = vars.fx0
48
48
 
49
49
  def new_closure(backward=True) -> _ScalarLoss:
50
50
  if backward:
@@ -56,35 +56,35 @@ class GradientApproximatorBase(OptimizerModule, ABC):
56
56
 
57
57
  return closure(False)
58
58
 
59
- state.closure = new_closure
59
+ vars.closure = new_closure
60
60
 
61
- def _step_make_target_(self, state: OptimizationState, params: TensorList):
62
- if state.closure is None: raise ValueError("gradient approximation requires closure")
61
+ def _step_make_target_(self, vars: OptimizationVars, params: TensorList):
62
+ if vars.closure is None: raise ValueError("gradient approximation requires closure")
63
63
 
64
- if self.requires_fx0: fx0 = state.evaluate_fx0_(False)
65
- else: fx0 = state.fx0
64
+ if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
65
+ else: fx0 = vars.fx0
66
66
 
67
- g, state.fx0, state.fx0_approx = self._make_ascent(state.closure, params, fx0)
68
- if self._default_step_target == 'ascent': state.ascent = g
69
- elif self._default_step_target == 'grad': state.set_grad_(g, params)
67
+ g, vars.fx0, vars.fx0_approx = self._make_ascent(vars.closure, params, fx0)
68
+ if self._default_step_target == 'ascent': vars.ascent = g
69
+ elif self._default_step_target == 'grad': vars.set_grad_(g, params)
70
70
  else: raise ValueError(f"Unknown target {self._default_step_target}")
71
71
 
72
72
  @torch.no_grad
73
- def step(self, state: OptimizationState):
73
+ def step(self, vars: OptimizationVars):
74
74
  params = self.get_params()
75
75
  if self._default_step_target == 'closure':
76
- self._step_make_closure_(state, params)
76
+ self._step_make_closure_(vars, params)
77
77
 
78
78
  else:
79
- self._step_make_target_(state, params)
79
+ self._step_make_target_(vars, params)
80
80
 
81
- return self._update_params_or_step_with_next(state, params)
81
+ return self._update_params_or_step_with_next(vars, params)
82
82
 
83
83
  @abstractmethod
84
84
  @torch.no_grad
85
85
  def _make_ascent(
86
86
  self,
87
- # state: OptimizationState,
87
+ # vars: OptimizationVars,
88
88
  closure: _ClosureType,
89
89
  params: TensorList,
90
90
  fx0: Any,
@@ -95,11 +95,6 @@ class GradientApproximatorBase(OptimizerModule, ABC):
95
95
 
96
96
  (ascent, fx0, fx0_approx)
97
97
 
98
- :code:`ascent` is the approximated gradient,
99
- :code:`fx0` is loss value strictly with initial parameters of the current step,
100
- :code:`fx0_approx` is loss value with perturbed parameters (will be returned by optimizer step if fx0 is None).
101
- :code:`fx0` and :code:`fx0_approx` can be None.
102
-
103
98
  Args:
104
99
  closure (_ClosureType): closure
105
100
  params (TensorList): parameters
@@ -4,7 +4,7 @@ import torch
4
4
 
5
5
  from ...utils.python_tools import _ScalarLoss
6
6
  from ...tensorlist import TensorList
7
- from ...core import _ClosureType, OptimizerModule, OptimizationState
7
+ from ...core import _ClosureType, OptimizerModule, OptimizationVars
8
8
  from ._fd_formulas import _FD_Formulas
9
9
  from .base_approximator import GradientApproximatorBase
10
10
 
@@ -121,16 +121,16 @@ class NewtonFDM(OptimizerModule):
121
121
  self.tol = tol
122
122
 
123
123
  @torch.no_grad
124
- def step(self, state):
124
+ def step(self, vars):
125
125
  """Returns a new ascent direction."""
126
- if state.closure is None: raise ValueError('NewtonFDM requires a closure.')
127
- if state.ascent is not None: raise ValueError('NewtonFDM got ascent direction')
126
+ if vars.closure is None: raise ValueError('NewtonFDM requires a closure.')
127
+ if vars.ascent is not None: raise ValueError('NewtonFDM got ascent direction')
128
128
 
129
129
  params = self.get_params()
130
130
  epsilons = self.get_group_key('eps')
131
131
 
132
132
  # evaluate fx0.
133
- if state.fx0 is None: state.fx0 = state.closure(False)
133
+ if vars.fx0 is None: vars.fx0 = vars.closure(False)
134
134
 
135
135
  # evaluate gradients and hessian via finite differences.
136
136
  grads = params.zeros_like()
@@ -152,7 +152,7 @@ class NewtonFDM(OptimizerModule):
152
152
  cur2 += 1
153
153
  continue
154
154
  _three_point_2cd_(
155
- closure = state.closure,
155
+ closure = vars.closure,
156
156
  idx1 = idx1,
157
157
  idx2 = idx2,
158
158
  p1 = flat_param1,
@@ -161,7 +161,7 @@ class NewtonFDM(OptimizerModule):
161
161
  hessian = hessian,
162
162
  eps1 = eps1,
163
163
  eps2 = eps2,
164
- fx0 = state.fx0,
164
+ fx0 = vars.fx0,
165
165
  i1 = cur1,
166
166
  i2 = cur2,
167
167
  )
@@ -181,18 +181,18 @@ class NewtonFDM(OptimizerModule):
181
181
  newton_step, success = _fallback_gd(hessian, gvec)
182
182
 
183
183
  # update params or pass the gradients to the child.
184
- state.ascent = grads.from_vec(newton_step)
184
+ vars.ascent = grads.from_vec(newton_step)
185
185
 
186
186
 
187
187
  # validate if newton step decreased loss
188
188
  if self.validate:
189
189
 
190
- params.sub_(state.ascent)
191
- fx1 = state.closure(False)
192
- params.add_(state.ascent)
190
+ params.sub_(vars.ascent)
191
+ fx1 = vars.closure(False)
192
+ params.add_(vars.ascent)
193
193
 
194
194
  # if loss increases, set ascent direction to gvec times lr
195
- if fx1 - state.fx0 > state.fx0 * self.tol:
196
- state.ascent = grads.from_vec(gvec) * self.gd_lr
195
+ if fx1 - vars.fx0 > vars.fx0 * self.tol:
196
+ vars.ascent = grads.from_vec(gvec) * self.gd_lr
197
197
 
198
- return self._update_params_or_step_with_next(state, params)
198
+ return self._update_params_or_step_with_next(vars, params)
@@ -4,7 +4,7 @@ import torch
4
4
 
5
5
  from ...utils.python_tools import _ScalarLoss
6
6
  from ...tensorlist import Distributions, TensorList
7
- from ...core import _ClosureType, OptimizerModule, OptimizationState
7
+ from ...core import _ClosureType, OptimizerModule, OptimizationVars
8
8
  from ._fd_formulas import _FD_Formulas
9
9
  from .base_approximator import GradientApproximatorBase
10
10
 
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
 
3
3
  from ...tensorlist import TensorList
4
- from ...core import OptimizationState
4
+ from ...core import OptimizationVars
5
5
  from .base_ls import LineSearchBase
6
6
 
7
7
 
@@ -32,23 +32,23 @@ class ArmijoLS(LineSearchBase):
32
32
  self.max_iter = max_iter
33
33
 
34
34
  @torch.no_grad
35
- def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
36
- if state.closure is None: raise RuntimeError(f"Line searches ({self.__class__.__name__}) require a closure")
37
- ascent = state.maybe_use_grad_(params)
38
- grad = state.maybe_compute_grad_(params)
35
+ def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
36
+ if vars.closure is None: raise RuntimeError(f"Line searches ({self.__class__.__name__}) require a closure")
37
+ ascent = vars.maybe_use_grad_(params)
38
+ grad = vars.maybe_compute_grad_(params)
39
39
  alpha = self.get_first_group_key('alpha')
40
- if state.fx0 is None: state.fx0 = state.closure(False)
40
+ if vars.fx0 is None: vars.fx0 = vars.closure(False)
41
41
 
42
42
  # loss decrease per lr=1 if function was linear
43
43
  decrease_per_lr = (grad*ascent).total_sum()
44
44
 
45
45
  for _ in range(self.max_iter):
46
- loss = self._evaluate_lr_(alpha, state.closure, ascent, params)
46
+ loss = self._evaluate_lr_(alpha, vars.closure, ascent, params)
47
47
 
48
48
  # expected decrease
49
49
  expected_decrease = decrease_per_lr * alpha
50
50
 
51
- if (state.fx0 - loss) / expected_decrease >= self.beta:
51
+ if (vars.fx0 - loss) / expected_decrease >= self.beta:
52
52
  return alpha
53
53
 
54
54
  alpha *= self.mul
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
5
5
  import torch
6
6
 
7
7
  from ...tensorlist import TensorList
8
- from ...core import _ClosureType, OptimizationState, OptimizerModule, _maybe_pass_backward
8
+ from ...core import _ClosureType, OptimizationVars, OptimizerModule, _maybe_pass_backward
9
9
  from ...utils.python_tools import _ScalarLoss
10
10
 
11
11
 
@@ -108,20 +108,20 @@ class LineSearchBase(OptimizerModule, ABC):
108
108
  if isinstance(v, torch.Tensor): return v.detach().cpu().item()
109
109
  return float(v)
110
110
 
111
- def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
111
+ def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
112
112
  """This should return the best lr."""
113
113
  ... # pylint:disable=unnecessary-ellipsis
114
114
 
115
115
  @torch.no_grad
116
- def step(self, state: OptimizationState):
116
+ def step(self, vars: OptimizationVars):
117
117
  self._reset()
118
118
  if self.log_lrs: self._lrs.append({})
119
119
 
120
120
  params = self.get_params()
121
- ascent_direction = state.maybe_use_grad_(params)
121
+ ascent_direction = vars.maybe_use_grad_(params)
122
122
 
123
123
  try:
124
- lr = self._find_best_lr(state, params) # pylint:disable=assignment-from-no-return
124
+ lr = self._find_best_lr(vars, params) # pylint:disable=assignment-from-no-return
125
125
  except MaxIterReached:
126
126
  lr = self._best_lr
127
127
 
@@ -133,7 +133,7 @@ class LineSearchBase(OptimizerModule, ABC):
133
133
  # otherwise undo the update by setting lr to 0 and instead multiply ascent direction by lr.
134
134
  self._set_lr_(0, ascent_direction, params)
135
135
  ascent_direction.mul_(self._best_lr)
136
- state.ascent = ascent_direction
137
- if state.fx0_approx is None: state.fx0_approx = self._lowest_loss
138
- return self.next_module.step(state)
136
+ vars.ascent = ascent_direction
137
+ if vars.fx0_approx is None: vars.fx0_approx = self._lowest_loss
138
+ return self.next_module.step(vars)
139
139
 
@@ -2,7 +2,7 @@ import numpy as np
2
2
  import torch
3
3
 
4
4
  from ...tensorlist import TensorList
5
- from ...core import OptimizationState
5
+ from ...core import OptimizationVars
6
6
  from .base_ls import LineSearchBase
7
7
 
8
8
  _FloatOrTensor = float | torch.Tensor
@@ -57,14 +57,14 @@ class DirectionalNewton(LineSearchBase):
57
57
  self.validate_step = validate_step
58
58
 
59
59
  @torch.no_grad
60
- def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
61
- if state.closure is None: raise ValueError('QuardaticLS requires closure')
62
- closure = state.closure
60
+ def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
61
+ if vars.closure is None: raise ValueError('QuardaticLS requires closure')
62
+ closure = vars.closure
63
63
 
64
64
  params = self.get_params()
65
- grad = state.maybe_compute_grad_(params)
66
- ascent = state.maybe_use_grad_(params)
67
- if state.fx0 is None: state.fx0 = state.closure(False) # at this stage maybe_compute_grad could've evaluated fx0
65
+ grad = vars.maybe_compute_grad_(params)
66
+ ascent = vars.maybe_use_grad_(params)
67
+ if vars.fx0 is None: vars.fx0 = vars.closure(False) # at this stage maybe_compute_grad could've evaluated fx0
68
68
 
69
69
  alpha: float = self.get_first_group_key('alpha') # this doesn't support variable lrs but we still want to support schedulers
70
70
 
@@ -78,7 +78,7 @@ class DirectionalNewton(LineSearchBase):
78
78
  if y1_prime != 0:
79
79
  xmin, a = _fit_and_minimize_quadratic_2points_grad(
80
80
  x1=0,
81
- y1=state.fx0,
81
+ y1=vars.fx0,
82
82
  y1_prime=-y1_prime,
83
83
  x2=alpha,
84
84
  # we stepped in the direction of minus gradient times lr.
@@ -172,14 +172,14 @@ class DirectionalNewton3Points(LineSearchBase):
172
172
  self.validate_step = validate_step
173
173
 
174
174
  @torch.no_grad
175
- def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
176
- if state.closure is None: raise ValueError('QuardaticLS requires closure')
177
- closure = state.closure
178
- ascent_direction = state.ascent
175
+ def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
176
+ if vars.closure is None: raise ValueError('QuardaticLS requires closure')
177
+ closure = vars.closure
178
+ ascent_direction = vars.ascent
179
179
  if ascent_direction is None: raise ValueError('Ascent direction is None')
180
180
  alpha: float = self.get_first_group_key('alpha')
181
181
 
182
- if state.fx0 is None: state.fx0 = state.closure(False)
182
+ if vars.fx0 is None: vars.fx0 = vars.closure(False)
183
183
  params = self.get_params()
184
184
 
185
185
  # make a step in the direction and evaluate f(x2)
@@ -190,7 +190,7 @@ class DirectionalNewton3Points(LineSearchBase):
190
190
 
191
191
  # if gradients weren't 0
192
192
  xmin, a = _newton_step_3points(
193
- 0, state.fx0,
193
+ 0, vars.fx0,
194
194
  # we stepped in the direction of minus ascent_direction.
195
195
  alpha, y2,
196
196
  alpha * 2, y3
@@ -5,7 +5,7 @@ import numpy as np
5
5
  import torch
6
6
 
7
7
  from ...tensorlist import TensorList
8
- from ...core import _ClosureType, OptimizationState
8
+ from ...core import _ClosureType, OptimizationVars
9
9
  from .base_ls import LineSearchBase
10
10
 
11
11
  class GridLS(LineSearchBase):
@@ -34,16 +34,16 @@ class GridLS(LineSearchBase):
34
34
  self.stop_on_worsened = stop_on_worsened
35
35
 
36
36
  @torch.no_grad
37
- def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
38
- if state.closure is None: raise ValueError("closure is not set")
39
- if state.ascent is None: raise ValueError("ascent_direction is not set")
37
+ def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
38
+ if vars.closure is None: raise ValueError("closure is not set")
39
+ if vars.ascent is None: raise ValueError("ascent_direction is not set")
40
40
 
41
41
  if self.stop_on_improvement:
42
- if state.fx0 is None: state.fx0 = state.closure(False)
43
- self._lowest_loss = state.fx0
42
+ if vars.fx0 is None: vars.fx0 = vars.closure(False)
43
+ self._lowest_loss = vars.fx0
44
44
 
45
45
  for lr in self.lrs:
46
- loss = self._evaluate_lr_(float(lr), state.closure, state.ascent, params)
46
+ loss = self._evaluate_lr_(float(lr), vars.closure, vars.ascent, params)
47
47
 
48
48
  # if worsened
49
49
  if self.stop_on_worsened and loss != self._lowest_loss:
@@ -7,7 +7,7 @@ except ModuleNotFoundError:
7
7
  scopt = typing.cast(typing.Any, None)
8
8
 
9
9
  from ...tensorlist import TensorList
10
- from ...core import OptimizationState
10
+ from ...core import OptimizationVars
11
11
 
12
12
  from .base_ls import LineSearchBase, MaxIterReached
13
13
 
@@ -45,11 +45,11 @@ class ScipyMinimizeScalarLS(LineSearchBase):
45
45
  self.options = options
46
46
 
47
47
  @torch.no_grad
48
- def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
48
+ def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
49
49
  try:
50
50
  res = scopt.minimize_scalar(
51
51
  self._evaluate_lr_ensure_float,
52
- args = (state.closure, state.ascent, params),
52
+ args = (vars.closure, vars.ascent, params),
53
53
  method = self.method,
54
54
  tol = self.tol,
55
55
  bracket = self.bracket,
@@ -40,7 +40,7 @@ class Alternate(OptimizerModule):
40
40
  if len(self.mode) != len(self.children):
41
41
  raise ValueError(f"got {len(self.children)} modules but {len(mode)} repeats, they should be the same")
42
42
 
43
- def step(self, state):
43
+ def step(self, vars):
44
44
  if self.mode == 'random':
45
45
  module = self.random.choice(list(self.children.values()))
46
46
 
@@ -58,8 +58,8 @@ class Alternate(OptimizerModule):
58
58
  self.remaining -= 1
59
59
 
60
60
  if self.next_module is None:
61
- return module.step(state)
61
+ return module.step(vars)
62
62
 
63
- state.ascent = module.return_ascent(state)
64
- return self._update_params_or_step_with_next(state)
63
+ vars.ascent = module.return_ascent(vars)
64
+ return self._update_params_or_step_with_next(vars)
65
65