torchzero 0.1.3__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. {torchzero-0.1.3/src/torchzero.egg-info → torchzero-0.1.5}/PKG-INFO +1 -1
  2. {torchzero-0.1.3 → torchzero-0.1.5}/pyproject.toml +1 -1
  3. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/core/__init__.py +1 -1
  4. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/core/module.py +72 -49
  5. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/core/tensorlist_optimizer.py +1 -1
  6. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/adaptive/adaptive.py +11 -11
  7. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/experimental/experimental.py +41 -41
  8. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/experimental/quad_interp.py +8 -8
  9. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/experimental/subspace.py +37 -37
  10. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/gradient_approximation/base_approximator.py +19 -24
  11. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/gradient_approximation/fdm.py +1 -1
  12. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
  13. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/gradient_approximation/rfdm.py +1 -1
  14. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/line_search/armijo.py +8 -8
  15. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/line_search/base_ls.py +8 -8
  16. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/line_search/directional_newton.py +14 -14
  17. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/line_search/grid_ls.py +7 -7
  18. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
  19. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/meta/alternate.py +4 -4
  20. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/meta/grafting.py +23 -23
  21. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/meta/optimizer_wrapper.py +14 -14
  22. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/meta/return_overrides.py +8 -8
  23. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/misc/accumulate.py +6 -6
  24. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/misc/basic.py +16 -16
  25. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/misc/lr.py +2 -2
  26. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/misc/multistep.py +7 -7
  27. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/misc/on_increase.py +9 -9
  28. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/momentum/momentum.py +4 -4
  29. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/operations/multi.py +44 -44
  30. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/operations/reduction.py +28 -28
  31. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/operations/singular.py +9 -9
  32. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/optimizers/adagrad.py +1 -1
  33. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/optimizers/adam.py +8 -8
  34. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/optimizers/lion.py +1 -1
  35. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/optimizers/rmsprop.py +1 -1
  36. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/optimizers/rprop.py +1 -1
  37. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/optimizers/sgd.py +2 -2
  38. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/orthogonalization/newtonschulz.py +3 -3
  39. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/orthogonalization/svd.py +1 -1
  40. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/regularization/dropout.py +1 -1
  41. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/regularization/noise.py +3 -3
  42. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/regularization/normalization.py +5 -5
  43. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/regularization/ortho_grad.py +1 -1
  44. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/regularization/weight_decay.py +1 -1
  45. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/scheduling/lr_schedulers.py +2 -2
  46. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/scheduling/step_size.py +8 -8
  47. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/second_order/newton.py +12 -12
  48. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/smoothing/__init__.py +1 -1
  49. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
  50. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
  51. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/weight_averaging/ema.py +3 -3
  52. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/weight_averaging/swa.py +8 -8
  53. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/first_order/forward_gradient.py +1 -1
  54. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/modular.py +4 -4
  55. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/tensorlist.py +8 -1
  56. {torchzero-0.1.3 → torchzero-0.1.5/src/torchzero.egg-info}/PKG-INFO +1 -1
  57. {torchzero-0.1.3 → torchzero-0.1.5}/LICENSE +0 -0
  58. {torchzero-0.1.3 → torchzero-0.1.5}/README.md +0 -0
  59. {torchzero-0.1.3 → torchzero-0.1.5}/setup.cfg +0 -0
  60. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/__init__.py +0 -0
  61. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/__init__.py +0 -0
  62. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/adaptive/__init__.py +0 -0
  63. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/experimental/__init__.py +0 -0
  64. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/gradient_approximation/__init__.py +0 -0
  65. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/gradient_approximation/_fd_formulas.py +0 -0
  66. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/gradient_approximation/forward_gradient.py +0 -0
  67. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/line_search/__init__.py +0 -0
  68. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/meta/__init__.py +0 -0
  69. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/misc/__init__.py +0 -0
  70. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/momentum/__init__.py +0 -0
  71. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/operations/__init__.py +0 -0
  72. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/optimizers/__init__.py +0 -0
  73. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/orthogonalization/__init__.py +0 -0
  74. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/quasi_newton/__init__.py +0 -0
  75. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/regularization/__init__.py +0 -0
  76. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/scheduling/__init__.py +0 -0
  77. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/second_order/__init__.py +0 -0
  78. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/modules/weight_averaging/__init__.py +0 -0
  79. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/__init__.py +0 -0
  80. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/experimental/__init__.py +0 -0
  81. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/experimental/experimental.py +0 -0
  82. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/experimental/ray_search.py +0 -0
  83. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/first_order/__init__.py +0 -0
  84. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/first_order/cautious.py +0 -0
  85. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/first_order/optimizers.py +0 -0
  86. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/quasi_newton/__init__.py +0 -0
  87. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/quasi_newton/directional_newton.py +0 -0
  88. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/second_order/__init__.py +0 -0
  89. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/second_order/newton.py +0 -0
  90. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/wrappers/__init__.py +0 -0
  91. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/wrappers/nevergrad.py +0 -0
  92. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/wrappers/nlopt.py +0 -0
  93. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/wrappers/scipy.py +0 -0
  94. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/zeroth_order/__init__.py +0 -0
  95. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/zeroth_order/fdm.py +0 -0
  96. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/zeroth_order/newton_fdm.py +0 -0
  97. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/zeroth_order/rfdm.py +0 -0
  98. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/optim/zeroth_order/rs.py +0 -0
  99. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/random/__init__.py +0 -0
  100. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/random/random.py +0 -0
  101. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/utils/__init__.py +0 -0
  102. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/utils/compile.py +0 -0
  103. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/utils/derivatives.py +0 -0
  104. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/utils/python_tools.py +0 -0
  105. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero/utils/torch_tools.py +0 -0
  106. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero.egg-info/SOURCES.txt +0 -0
  107. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero.egg-info/dependency_links.txt +0 -0
  108. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero.egg-info/requires.txt +0 -0
  109. {torchzero-0.1.3 → torchzero-0.1.5}/src/torchzero.egg-info/top_level.txt +0 -0
  110. {torchzero-0.1.3 → torchzero-0.1.5}/tests/test_against_reference.py +0 -0
  111. {torchzero-0.1.3 → torchzero-0.1.5}/tests/test_modules.py +0 -0
  112. {torchzero-0.1.3 → torchzero-0.1.5}/tests/test_tensorlist.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: torchzero
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  License: MIT License
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
  name = "torchzero"
7
7
  description = "Modular optimization library for PyTorch."
8
8
 
9
- version = "0.1.3"
9
+ version = "0.1.5"
10
10
  dependencies = [
11
11
  "torch",
12
12
  "numpy",
@@ -1,7 +1,7 @@
1
1
  import sys
2
2
 
3
3
  from .module import (
4
- OptimizationState,
4
+ OptimizationVars,
5
5
  OptimizerModule,
6
6
  _Chain,
7
7
  _Chainable,
@@ -23,8 +23,8 @@ def _get_loss(fx0, fx0_approx):
23
23
  return fx0
24
24
 
25
25
 
26
- class OptimizationState:
27
- """Holds optimization state. This is usually automatically created by :any:`torchzero.optim.Modular`."""
26
+ class OptimizationVars:
27
+ """Holds optimization variables. This is usually automatically created by :any:`torchzero.optim.Modular`."""
28
28
  def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
29
29
 
30
30
  self.closure: _ClosureType | None = closure
@@ -121,23 +121,23 @@ class OptimizationState:
121
121
  Returns:
122
122
  A copy of this OptimizationState.
123
123
  """
124
- state = OptimizationState(self.closure, self.model)
125
- state.fx0 = self.fx0
126
- state.fx0_approx = self.fx0_approx
127
- state.grad = self.grad
124
+ vars = OptimizationVars(self.closure, self.model)
125
+ vars.fx0 = self.fx0
126
+ vars.fx0_approx = self.fx0_approx
127
+ vars.grad = self.grad
128
128
 
129
- if clone_ascent and self.ascent is not None: state.ascent = self.ascent.clone()
130
- else: state.ascent = self.ascent
129
+ if clone_ascent and self.ascent is not None: vars.ascent = self.ascent.clone()
130
+ else: vars.ascent = self.ascent
131
131
 
132
- return state
132
+ return vars
133
133
 
134
- def update_attrs_(self, state: "OptimizationState"):
134
+ def update_attrs_(self, vars: "OptimizationVars"):
135
135
  """Updates attributes of this state with attributes of another state.
136
136
 
137
137
  This updates `grad`, `fx0` and `fx0_approx`."""
138
- if state.grad is not None: self.grad = state.grad
139
- if state.fx0 is not None: self.fx0 = state.fx0
140
- if state.fx0_approx is not None: self.fx0_approx = state.fx0_approx
138
+ if vars.grad is not None: self.grad = vars.grad
139
+ if vars.fx0 is not None: self.fx0 = vars.fx0
140
+ if vars.fx0_approx is not None: self.fx0_approx = vars.fx0_approx
141
141
 
142
142
 
143
143
  def add_post_step_hook(self, hook: Callable):
@@ -283,7 +283,7 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
283
283
  for c in self.children.values():
284
284
  self._update_child_params_(c)
285
285
 
286
- def _update_params_or_step_with_next(self, state: OptimizationState, params: TensorList | None = None) -> _ScalarLoss | None:
286
+ def _update_params_or_step_with_next(self, vars: OptimizationVars, params: TensorList | None = None) -> _ScalarLoss | None:
287
287
  """If this has no children, update params and return loss. Otherwise step with the next module.
288
288
 
289
289
  Optionally pass params to not recreate them if you've already made them.
@@ -293,29 +293,29 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
293
293
  """
294
294
  # if this has no children, update params and return loss.
295
295
  if self.next_module is None:
296
- if state.ascent is None: raise ValueError('Called _update_params_or_step_with_child but ascent_direction is None...')
296
+ if vars.ascent is None: raise ValueError('Called _update_params_or_step_with_child but ascent_direction is None...')
297
297
  if params is None: params = self.get_params()
298
- params -= state.ascent # type:ignore
299
- return state.get_loss()
298
+ params -= vars.ascent # type:ignore
299
+ return vars.get_loss()
300
300
 
301
301
  # otherwise pass the updated ascent direction to the child
302
- return self.next_module.step(state)
302
+ return self.next_module.step(vars)
303
303
 
304
304
  @torch.no_grad
305
- def _step_update_closure(self, state: OptimizationState) -> _ScalarLoss | None:
305
+ def _step_update_closure(self, vars: OptimizationVars) -> _ScalarLoss | None:
306
306
  """Create a new closure which applies the `_update` method and passes it to the next module."""
307
- if state.closure is None: raise ValueError('If target == "closure", closure must be provided')
307
+ if vars.closure is None: raise ValueError('If target == "closure", closure must be provided')
308
308
 
309
309
  params = self.get_params()
310
- closure = state.closure # closure shouldn't reference state attribute because it can be changed
311
- ascent_direction = state.ascent
310
+ closure = vars.closure # closure shouldn't reference state attribute because it can be changed
311
+ ascent_direction = vars.ascent
312
312
 
313
313
  def update_closure(backward = True):
314
314
  loss = _maybe_pass_backward(closure, backward)
315
315
 
316
316
  # on backward, update the ascent direction
317
317
  if backward:
318
- grad = self._update(state, ascent_direction) # type:ignore
318
+ grad = self._update(vars, ascent_direction) # type:ignore
319
319
  # set new ascent direction as gradients
320
320
  # (accumulation doesn't make sense here as closure always calls zero_grad)
321
321
  for p, g in zip(params,grad):
@@ -327,12 +327,12 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
327
327
  # if self.next_module is None:
328
328
  # raise ValueError(f'{self.__class__.__name__} has no child to step with (maybe set "target" from "closure" to something else??).')
329
329
 
330
- state.closure = update_closure
331
- return self._update_params_or_step_with_next(state)
330
+ vars.closure = update_closure
331
+ return self._update_params_or_step_with_next(vars)
332
332
 
333
333
 
334
334
  @torch.no_grad
335
- def _step_update_target(self, state: OptimizationState) -> _ScalarLoss | None:
335
+ def _step_update_target(self, vars: OptimizationVars) -> _ScalarLoss | None:
336
336
  """Apply _update method to the ascent direction and pass it to the child, or make a step if child is None."""
337
337
  # the following code by default uses `_update` method which simple modules can override.
338
338
  # But you can also just override the entire `step`.
@@ -342,50 +342,73 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
342
342
  # update ascent direction
343
343
  if self._default_step_target == 'ascent':
344
344
  # if this is the first module, it uses the gradients
345
- if state.grad is None: params = self.get_params()
346
- t = state.maybe_use_grad_(params)
347
- state.ascent = self._update(state, t)
345
+ if vars.grad is None: params = self.get_params()
346
+ t = vars.maybe_use_grad_(params)
347
+ vars.ascent = self._update(vars, t)
348
348
 
349
349
  # update gradients
350
350
  elif self._default_step_target == 'grad':
351
351
  if params is None: params = self.get_params()
352
- g = state.maybe_compute_grad_(params)
353
- g = self._update(state, g)
354
- state.set_grad_(g, params)
352
+ g = vars.maybe_compute_grad_(params)
353
+ g = self._update(vars, g)
354
+ vars.set_grad_(g, params)
355
355
  else:
356
356
  raise ValueError(f"Invalid {self._default_step_target = }")
357
357
 
358
358
  # perform an update with the new state, or pass it to the child.
359
- return self._update_params_or_step_with_next(state, params=params)
359
+ return self._update_params_or_step_with_next(vars, params=params)
360
360
 
361
361
  @torch.no_grad
362
362
  def step( # type:ignore # pylint:disable=signature-differs # pylint:disable = arguments-renamed
363
363
  self,
364
- state: OptimizationState
364
+ vars: OptimizationVars
365
365
  ) -> _ScalarLoss | None:
366
366
  """Perform a single optimization step to update parameter."""
367
367
 
368
- if self._default_step_target == 'closure': return self._step_update_closure(state)
369
- return self._step_update_target(state)
368
+ if self._default_step_target == 'closure': return self._step_update_closure(vars)
369
+ return self._step_update_target(vars)
370
370
 
371
371
  @torch.no_grad
372
- def _update(self, state: OptimizationState, ascent: TensorList) -> TensorList:
372
+ def _update(self, vars: OptimizationVars, ascent: TensorList) -> TensorList:
373
373
  """Update `ascent_direction` and return the new ascent direction (but it may update it in place).
374
- Make sure it doesn't return anything from `state` to avoid future modules modifying that in-place.
374
+ Make sure it doesn't return anything from `self.state` to avoid future modules modifying that in-place.
375
375
 
376
376
  Before calling `_update`, if ascent direction was not provided to `step`, it will be set to the gradients.
377
377
 
378
378
  After generating a new ascent direction with this `_update` method,
379
379
  if this module has no child, ascent direction will be subtracted from params.
380
380
  Otherwise everything is passed to the child."""
381
+ params = self.get_params()
382
+ gradients = ascent.grad
383
+ if gradients is None: gradients = [None] * len(params)
384
+ settings = tuple(self.get_all_group_keys(list).items())
385
+
386
+ new_ascent = TensorList()
387
+ for i, (asc, param, grad) in enumerate(zip(ascent, params, gradients)):
388
+ kwargs = {"vars": vars, "ascent": asc, "param": param, "grad": grad}
389
+ kwargs.update({k:v[i] for k,v in settings})
390
+ new_ascent.append(self._single_tensor_update(**kwargs))
391
+ return new_ascent
392
+
393
+
394
+ def _single_tensor_update(self, vars: OptimizationVars, ascent: torch.Tensor, param: torch.Tensor, grad: torch.Tensor | None, **kwargs) -> torch.Tensor:
395
+ """Update function for a single tensor.
396
+
397
+ Args:
398
+ vars (OptimizationVars): holds loss, gradients, etc.
399
+ ascent (torch.Tensor): update tensor.
400
+ param (torch.Tensor): parameter tensor.
401
+ grad (torch.Tensor | None): gradient tensor (may be None)
402
+ **kwargs: all per-parameter settings (stuff that you put in `defaults = dict(beta1=beta1, beta2=beta2, eps=eps)`).
403
+ """
381
404
  raise NotImplementedError()
382
405
 
383
- def return_ascent(self, state: OptimizationState, params=None) -> TensorList:
406
+ def return_ascent(self, vars: OptimizationVars, params=None) -> TensorList:
384
407
  """step with this module and return the ascent as tensorlist"""
385
408
  if params is None: params = self.get_params()
386
409
  true_next = self.next_module
387
410
  self.next_module = _ReturnAscent(params) # type:ignore
388
- ascent: TensorList = self.step(state) # type:ignore
411
+ ascent: TensorList = self.step(vars) # type:ignore
389
412
  self.next_module = true_next
390
413
  return ascent
391
414
 
@@ -412,8 +435,8 @@ class _ReturnAscent:
412
435
  self.next_module = None
413
436
 
414
437
  @torch.no_grad
415
- def step(self, state: OptimizationState) -> TensorList: # type:ignore
416
- update = state.maybe_use_grad_(self.params) # this will execute the closure which might be modified
438
+ def step(self, vars: OptimizationVars) -> TensorList: # type:ignore
439
+ update = vars.maybe_use_grad_(self.params) # this will execute the closure which might be modified
417
440
  return update
418
441
 
419
442
 
@@ -424,13 +447,13 @@ class _MaybeReturnAscent(OptimizerModule):
424
447
  self._return_ascent = False
425
448
 
426
449
  @torch.no_grad
427
- def step(self, state: OptimizationState):
450
+ def step(self, vars: OptimizationVars):
428
451
  assert self.next_module is None, self.next_module
429
452
 
430
453
  if self._return_ascent:
431
- return state.ascent
454
+ return vars.ascent
432
455
 
433
- return self._update_params_or_step_with_next(state)
456
+ return self._update_params_or_step_with_next(vars)
434
457
 
435
458
  _Chainable = OptimizerModule | Iterable[OptimizerModule]
436
459
 
@@ -456,16 +479,16 @@ class _Chain(OptimizerModule):
456
479
  self._chain_modules = flat_modules
457
480
 
458
481
  @torch.no_grad
459
- def step(self, state: OptimizationState):
482
+ def step(self, vars: OptimizationVars):
460
483
  # no next module, step with the child
461
484
  if self.next_module is None:
462
- return self.children['first'].step(state)
485
+ return self.children['first'].step(vars)
463
486
 
464
487
  # return ascent and pass it to next module
465
488
  # we do this because updating parameters directly is often more efficient
466
489
  params = self.get_params()
467
490
  self._last_module.next_module = _ReturnAscent(params) # type:ignore
468
- state.ascent: TensorList = self.children['first'].step(state) # type:ignore
491
+ vars.ascent: TensorList = self.children['first'].step(vars) # type:ignore
469
492
  self._last_module.next_module = None
470
493
 
471
- return self._update_params_or_step_with_next(state)
494
+ return self._update_params_or_step_with_next(vars)
@@ -149,7 +149,7 @@ class TensorListOptimizer(torch.optim.Optimizer, ABC):
149
149
 
150
150
  # def get_group_keys[CLS: MutableSequence](self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
151
151
  def get_group_keys(self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
152
- """Returns a TensorList with the param_groups `key` setting of each param."""
152
+ """Returns a list with the param_groups `key` setting of each param."""
153
153
 
154
154
  all_values: list[CLS] = [cls() for _ in keys]
155
155
  for group in self.param_groups:
@@ -33,16 +33,16 @@ class Cautious(OptimizerModule):
33
33
  self.mode: typing.Literal['zero', 'grad', 'backtrack'] = mode
34
34
 
35
35
  @torch.no_grad
36
- def _update(self, state, ascent):
36
+ def _update(self, vars, ascent):
37
37
  params = self.get_params()
38
- grad = state.maybe_compute_grad_(params)
38
+ grad = vars.maybe_compute_grad_(params)
39
39
 
40
40
  # mask will be > 0 for parameters where both signs are the same
41
41
  mask = (ascent * grad) > 0
42
42
  if self.mode in ('zero', 'grad'):
43
43
  if self.normalize and self.mode == 'zero':
44
44
  fmask = mask.to(ascent[0].dtype)
45
- fmask /= fmask.total_mean() + self.eps
45
+ fmask /= fmask.total_mean() + self.eps # type:ignore
46
46
  else:
47
47
  fmask = mask
48
48
 
@@ -66,9 +66,9 @@ class UseGradSign(OptimizerModule):
66
66
  super().__init__({})
67
67
 
68
68
  @torch.no_grad
69
- def _update(self, state, ascent):
69
+ def _update(self, vars, ascent):
70
70
  params = self.get_params()
71
- grad = state.maybe_compute_grad_(params)
71
+ grad = vars.maybe_compute_grad_(params)
72
72
 
73
73
  return ascent.abs_().mul_(grad.sign())
74
74
 
@@ -80,9 +80,9 @@ class UseGradMagnitude(OptimizerModule):
80
80
  super().__init__({})
81
81
 
82
82
  @torch.no_grad
83
- def _update(self, state, ascent):
83
+ def _update(self, vars, ascent):
84
84
  params = self.get_params()
85
- grad = state.maybe_compute_grad_(params)
85
+ grad = vars.maybe_compute_grad_(params)
86
86
 
87
87
  return ascent.sign_().mul_(grad.abs())
88
88
 
@@ -109,10 +109,10 @@ class ScaleLRBySignChange(OptimizerModule):
109
109
  self.use_grad = use_grad
110
110
 
111
111
  @torch.no_grad
112
- def _update(self, state, ascent):
112
+ def _update(self, vars, ascent):
113
113
  params = self.get_params()
114
114
 
115
- if self.use_grad: cur = state.maybe_compute_grad_(params)
115
+ if self.use_grad: cur = vars.maybe_compute_grad_(params)
116
116
  else: cur = ascent
117
117
 
118
118
  nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
@@ -168,10 +168,10 @@ class NegateOnSignChange(OptimizerModule):
168
168
  self.current_step = 0
169
169
 
170
170
  @torch.no_grad
171
- def _update(self, state, ascent):
171
+ def _update(self, vars, ascent):
172
172
  params = self.get_params()
173
173
 
174
- if self.use_grad: cur = state.maybe_compute_grad_(params)
174
+ if self.use_grad: cur = vars.maybe_compute_grad_(params)
175
175
  else: cur = ascent
176
176
 
177
177
  prev = self.get_state_key('prev')
@@ -35,9 +35,9 @@ class MinibatchRprop(OptimizerModule):
35
35
  self.next_mode = next_mode
36
36
 
37
37
  @torch.no_grad
38
- def step(self, state):
39
- if state.closure is None: raise ValueError("Minibatch Rprop requires closure")
40
- if state.ascent is not None: raise ValueError("Minibatch Rprop must be the first module.")
38
+ def step(self, vars):
39
+ if vars.closure is None: raise ValueError("Minibatch Rprop requires closure")
40
+ if vars.ascent is not None: raise ValueError("Minibatch Rprop must be the first module.")
41
41
  params = self.get_params()
42
42
 
43
43
  nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
@@ -47,7 +47,7 @@ class MinibatchRprop(OptimizerModule):
47
47
  params=params
48
48
  )
49
49
 
50
- g1_sign = state.maybe_compute_grad_(params).sign() # no inplace to not modify grads
50
+ g1_sign = vars.maybe_compute_grad_(params).sign() # no inplace to not modify grads
51
51
  # initialize on 1st iteration
52
52
  if self.current_step == 0:
53
53
  magnitudes.fill_(self.get_group_key('alpha')).clamp_(lb, ub)
@@ -58,8 +58,8 @@ class MinibatchRprop(OptimizerModule):
58
58
  # first step
59
59
  ascent = g1_sign.mul_(magnitudes).mul_(allowed)
60
60
  params -= ascent
61
- with torch.enable_grad(): state.fx0_approx = state.closure()
62
- f0 = state.fx0; f1 = state.fx0_approx
61
+ with torch.enable_grad(): vars.fx0_approx = vars.closure()
62
+ f0 = vars.fx0; f1 = vars.fx0_approx
63
63
  assert f0 is not None and f1 is not None
64
64
 
65
65
  # if loss increased, reduce all lrs and undo the update
@@ -73,9 +73,9 @@ class MinibatchRprop(OptimizerModule):
73
73
  # on `continue` we move to params after 1st update
74
74
  # therefore state must be updated to have all attributes after 1st update
75
75
  if self.next_mode == 'continue':
76
- state.fx0 = state.fx0_approx
77
- state.grad = params.ensure_grad_().grad
78
- sign = state.grad.sign()
76
+ vars.fx0 = vars.fx0_approx
77
+ vars.grad = params.ensure_grad_().grad
78
+ sign = vars.grad.sign()
79
79
 
80
80
  else:
81
81
  sign = params.ensure_grad_().grad.sign_() # can use in-place as this is not fx0 grad
@@ -109,19 +109,19 @@ class MinibatchRprop(OptimizerModule):
109
109
 
110
110
  # update params or step
111
111
  if self.next_mode == 'continue' or (self.next_mode == 'add' and self.next_module is None):
112
- state.ascent = ascent2
113
- return self._update_params_or_step_with_next(state, params)
112
+ vars.ascent = ascent2
113
+ return self._update_params_or_step_with_next(vars, params)
114
114
 
115
115
  if self.next_mode == 'add':
116
116
  # undo 1st step
117
117
  params += ascent
118
- state.ascent = ascent + ascent2
119
- return self._update_params_or_step_with_next(state, params)
118
+ vars.ascent = ascent + ascent2
119
+ return self._update_params_or_step_with_next(vars, params)
120
120
 
121
121
  if self.next_mode == 'undo':
122
122
  params += ascent
123
- state.ascent = ascent2
124
- return self._update_params_or_step_with_next(state, params)
123
+ vars.ascent = ascent2
124
+ return self._update_params_or_step_with_next(vars, params)
125
125
 
126
126
  raise ValueError(f'invalid next_mode: {self.next_mode}')
127
127
 
@@ -140,9 +140,9 @@ class GradMin(OptimizerModule):
140
140
  self.create_graph = create_graph
141
141
 
142
142
  @torch.no_grad
143
- def step(self, state):
144
- if state.closure is None: raise ValueError()
145
- if state.ascent is not None:
143
+ def step(self, vars):
144
+ if vars.closure is None: raise ValueError()
145
+ if vars.ascent is not None:
146
146
  raise ValueError("GradMin doesn't accept ascent_direction")
147
147
 
148
148
  params = self.get_params()
@@ -150,26 +150,26 @@ class GradMin(OptimizerModule):
150
150
 
151
151
  self.zero_grad()
152
152
  with torch.enable_grad():
153
- state.fx0 = state.closure(False)
154
- grads = jacobian([state.fx0], params, create_graph=True, batched=False) # type:ignore
153
+ vars.fx0 = vars.closure(False)
154
+ grads = jacobian([vars.fx0], params, create_graph=True, batched=False) # type:ignore
155
155
  grads = TensorList(grads).squeeze_(0)
156
156
  if self.square:
157
157
  grads = grads ** 2
158
158
  else:
159
159
  grads = grads.abs()
160
160
 
161
- if self.maximize_grad: grads: TensorList = grads - (state.fx0 * loss_term) # type:ignore
162
- else: grads = grads + (state.fx0 * loss_term)
161
+ if self.maximize_grad: grads: TensorList = grads - (vars.fx0 * loss_term) # type:ignore
162
+ else: grads = grads + (vars.fx0 * loss_term)
163
163
  grad_mean = torch.sum(torch.stack(grads.sum())) / grads.total_numel()
164
164
 
165
165
  if self.create_graph: grad_mean.backward(create_graph=True)
166
166
  else: grad_mean.backward(retain_graph=False)
167
167
 
168
- if self.maximize_grad: state.grad = params.ensure_grad_().grad.neg_()
169
- else: state.grad = params.ensure_grad_().grad
168
+ if self.maximize_grad: vars.grad = params.ensure_grad_().grad.neg_()
169
+ else: vars.grad = params.ensure_grad_().grad
170
170
 
171
- state.maybe_use_grad_(params)
172
- return self._update_params_or_step_with_next(state)
171
+ vars.maybe_use_grad_(params)
172
+ return self._update_params_or_step_with_next(vars)
173
173
 
174
174
 
175
175
  class HVPDiagNewton(OptimizerModule):
@@ -182,26 +182,26 @@ class HVPDiagNewton(OptimizerModule):
182
182
  super().__init__(dict(eps=eps))
183
183
 
184
184
  @torch.no_grad
185
- def step(self, state):
186
- if state.closure is None: raise ValueError()
187
- if state.ascent is not None:
185
+ def step(self, vars):
186
+ if vars.closure is None: raise ValueError()
187
+ if vars.ascent is not None:
188
188
  raise ValueError("HVPDiagNewton doesn't accept ascent_direction")
189
189
 
190
190
  params = self.get_params()
191
191
  eps = self.get_group_key('eps')
192
- grad_fx0 = state.maybe_compute_grad_(params).clone()
193
- state.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten
192
+ grad_fx0 = vars.maybe_compute_grad_(params).clone()
193
+ vars.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten
194
194
 
195
195
  params += grad_fx0 * eps
196
- with torch.enable_grad(): _ = state.closure()
196
+ with torch.enable_grad(): _ = vars.closure()
197
197
 
198
198
  params -= grad_fx0 * eps
199
199
 
200
200
  newton = grad_fx0 * ((grad_fx0 * eps) / (params.grad - grad_fx0))
201
201
  newton.nan_to_num_(0,0,0)
202
202
 
203
- state.ascent = newton
204
- return self._update_params_or_step_with_next(state)
203
+ vars.ascent = newton
204
+ return self._update_params_or_step_with_next(vars)
205
205
 
206
206
 
207
207
 
@@ -219,11 +219,11 @@ class ReduceOutwardLR(OptimizerModule):
219
219
  self.invert = invert
220
220
 
221
221
  @torch.no_grad
222
- def _update(self, state, ascent):
222
+ def _update(self, vars, ascent):
223
223
  params = self.get_params()
224
224
  mul = self.get_group_key('mul')
225
225
 
226
- if self.use_grad: cur = state.maybe_compute_grad_(params)
226
+ if self.use_grad: cur = vars.maybe_compute_grad_(params)
227
227
  else: cur = ascent
228
228
 
229
229
  # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
@@ -241,7 +241,7 @@ class NoiseSign(OptimizerModule):
241
241
  self.distribution:Distributions = distribution
242
242
 
243
243
 
244
- def _update(self, state, ascent):
244
+ def _update(self, vars, ascent):
245
245
  return ascent.sample_like(self.alpha, self.distribution).copysign_(ascent)
246
246
 
247
247
  class ParamSign(OptimizerModule):
@@ -250,7 +250,7 @@ class ParamSign(OptimizerModule):
250
250
  super().__init__({})
251
251
 
252
252
 
253
- def _update(self, state, ascent):
253
+ def _update(self, vars, ascent):
254
254
  params = self.get_params()
255
255
 
256
256
  return params.copysign(ascent)
@@ -261,7 +261,7 @@ class NegParamSign(OptimizerModule):
261
261
  super().__init__({})
262
262
 
263
263
 
264
- def _update(self, state, ascent):
264
+ def _update(self, vars, ascent):
265
265
  neg_params = self.get_params().abs()
266
266
  max = neg_params.total_max()
267
267
  neg_params = neg_params.neg_().add(max)
@@ -274,7 +274,7 @@ class InvParamSign(OptimizerModule):
274
274
  self.eps = eps
275
275
 
276
276
 
277
- def _update(self, state, ascent):
277
+ def _update(self, vars, ascent):
278
278
  inv_params = self.get_params().abs().add_(self.eps).reciprocal_()
279
279
  return inv_params.copysign(ascent)
280
280
 
@@ -286,7 +286,7 @@ class ParamWhereConsistentSign(OptimizerModule):
286
286
  self.eps = eps
287
287
 
288
288
 
289
- def _update(self, state, ascent):
289
+ def _update(self, vars, ascent):
290
290
  params = self.get_params()
291
291
  same_sign = params.sign() == ascent.sign()
292
292
  ascent.masked_set_(same_sign, params)
@@ -4,7 +4,7 @@ import numpy as np
4
4
  import torch
5
5
 
6
6
  from ...tensorlist import TensorList
7
- from ...core import OptimizationState
7
+ from ...core import OptimizationVars
8
8
  from ..line_search.base_ls import LineSearchBase
9
9
 
10
10
  _FloatOrTensor = float | torch.Tensor
@@ -47,12 +47,12 @@ class QuadraticInterpolation2Point(LineSearchBase):
47
47
  self.min_dist = min_dist
48
48
 
49
49
  @torch.no_grad
50
- def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
51
- if state.closure is None: raise ValueError('QuardaticLS requires closure')
52
- closure = state.closure
53
- if state.fx0 is None: state.fx0 = state.closure(False)
54
- grad = state.grad
55
- if grad is None: grad = state.ascent # in case we used FDM
50
+ def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
51
+ if vars.closure is None: raise ValueError('QuardaticLS requires closure')
52
+ closure = vars.closure
53
+ if vars.fx0 is None: vars.fx0 = vars.closure(False)
54
+ grad = vars.grad
55
+ if grad is None: grad = vars.ascent # in case we used FDM
56
56
  if grad is None: raise ValueError('QuardaticLS requires gradients.')
57
57
 
58
58
  params = self.get_params()
@@ -67,7 +67,7 @@ class QuadraticInterpolation2Point(LineSearchBase):
67
67
 
68
68
  # make min_dist relative
69
69
  min_dist = abs(lr) * self.min_dist
70
- points = sorted([Point(0, _ensure_float(state.fx0), dfx0), Point(lr, _ensure_float(fx1))], key = lambda x: x.fx)
70
+ points = sorted([Point(0, _ensure_float(vars.fx0), dfx0), Point(lr, _ensure_float(fx1))], key = lambda x: x.fx)
71
71
 
72
72
  for i in range(self.max_evals):
73
73
  # find new point