torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/nystrom.py
@@ -6,7 +6,7 @@ import torch
  from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward

- from ...core import Chainable, apply, Module
+ from ...core import Chainable, apply_transform, Module
  from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg

  class NystromSketchAndSolve(Module):
@@ -15,7 +15,7 @@ class NystromSketchAndSolve(Module):
  rank: int,
  reg: float = 1e-3,
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
- h=1e-2,
+ h=1e-3,
  inner: Chainable | None = None,
  seed: int | None = None,
  ):
@@ -26,10 +26,10 @@ class NystromSketchAndSolve(Module):
  self.set_child('inner', inner)

  @torch.no_grad
- def step(self, vars):
- params = TensorList(vars.params)
+ def step(self, var):
+ params = TensorList(var.params)

- closure = vars.closure
+ closure = var.closure
  if closure is None: raise RuntimeError('NewtonCG requires closure')

  settings = self.settings[params[0]]
@@ -47,7 +47,7 @@ class NystromSketchAndSolve(Module):

  # ---------------------- Hessian vector product function --------------------- #
  if hvp_method == 'autograd':
- grad = vars.get_grad(create_graph=True)
+ grad = var.get_grad(create_graph=True)

  def H_mm(x):
  with torch.enable_grad():
@@ -57,7 +57,7 @@ class NystromSketchAndSolve(Module):
  else:

  with torch.enable_grad():
- grad = vars.get_grad()
+ grad = var.get_grad()

  if hvp_method == 'forward':
  def H_mm(x):
@@ -74,14 +74,14 @@ class NystromSketchAndSolve(Module):


  # -------------------------------- inner step -------------------------------- #
- b = vars.get_update()
+ b = var.get_update()
  if 'inner' in self.children:
- b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)

  # ------------------------------ sketch&n&solve ------------------------------ #
  x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
- vars.update = vec_to_tensors(x, reference=params)
- return vars
+ var.update = vec_to_tensors(x, reference=params)
+ return var



@@ -93,7 +93,7 @@ class NystromPCG(Module):
  tol=1e-3,
  reg: float = 1e-6,
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
- h=1e-2,
+ h=1e-3,
  inner: Chainable | None = None,
  seed: int | None = None,
  ):
@@ -104,10 +104,10 @@ class NystromPCG(Module):
  self.set_child('inner', inner)

  @torch.no_grad
- def step(self, vars):
- params = TensorList(vars.params)
+ def step(self, var):
+ params = TensorList(var.params)

- closure = vars.closure
+ closure = var.closure
  if closure is None: raise RuntimeError('NewtonCG requires closure')

  settings = self.settings[params[0]]
@@ -129,7 +129,7 @@ class NystromPCG(Module):

  # ---------------------- Hessian vector product function --------------------- #
  if hvp_method == 'autograd':
- grad = vars.get_grad(create_graph=True)
+ grad = var.get_grad(create_graph=True)

  def H_mm(x):
  with torch.enable_grad():
@@ -139,7 +139,7 @@ class NystromPCG(Module):
  else:

  with torch.enable_grad():
- grad = vars.get_grad()
+ grad = var.get_grad()

  if hvp_method == 'forward':
  def H_mm(x):
@@ -156,13 +156,13 @@ class NystromPCG(Module):


  # -------------------------------- inner step -------------------------------- #
- b = vars.get_update()
+ b = var.get_update()
  if 'inner' in self.children:
- b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)

  # ------------------------------ sketch&n&solve ------------------------------ #
  x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
- vars.update = vec_to_tensors(x, reference=params)
- return vars
+ var.update = vec_to_tensors(x, reference=params)
+ return var

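For context on the hvp_method='autograd' branch used by both modules above: the H_mm callable is a standard double-backward Hessian-vector product. A minimal standalone sketch of that pattern with plain torch.autograd (illustrative names only, not the package's hvp helpers):

import torch

def hvp_autograd(loss_fn, params, v):
    # first backward pass keeps the graph so the gradient stays differentiable
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # dot(grad, v), then a second backward pass yields H @ v
    dot = sum((g * vi).sum() for g, vi in zip(grads, v))
    return torch.autograd.grad(dot, params)

p = torch.randn(5, requires_grad=True)
v = [torch.randn(5)]
Hv = hvp_autograd(lambda: (p ** 3).sum(), [p], v)  # Hessian is diag(6p), so Hv[0] == 6 * p * v[0]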
torchzero/modules/smoothing/gaussian.py
@@ -6,7 +6,7 @@ from typing import Literal

  import torch

- from ...core import Modular, Module, Vars
+ from ...core import Modular, Module, Var
  from ...utils import NumberList, TensorList
  from ...utils.derivatives import jacobian_wrt
  from ..grad_approximation import GradApproximator, GradTarget
@@ -17,24 +17,24 @@ class Reformulation(Module, ABC):
  super().__init__(defaults)

  @abstractmethod
- def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], vars: Vars) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+ def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
  """returns loss and gradient, if backward is False then gradient can be None"""

- def pre_step(self, vars: Vars) -> Vars | None:
+ def pre_step(self, var: Var) -> Var | None:
  """This runs once before each step, whereas `closure` may run multiple times per step if further modules
  evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
- return vars
+ return var

- def step(self, vars):
- ret = self.pre_step(vars)
- if isinstance(ret, Vars): vars = ret
+ def step(self, var):
+ ret = self.pre_step(var)
+ if isinstance(ret, Var): var = ret

- if vars.closure is None: raise RuntimeError("Reformulation requires closure")
- params, closure = vars.params, vars.closure
+ if var.closure is None: raise RuntimeError("Reformulation requires closure")
+ params, closure = var.params, var.closure


  def modified_closure(backward=True):
- loss, grad = self.closure(backward, closure, params, vars)
+ loss, grad = self.closure(backward, closure, params, var)

  if grad is not None:
  for p,g in zip(params, grad):
@@ -42,8 +42,8 @@ class Reformulation(Module, ABC):

  return loss

- vars.closure = modified_closure
- return vars
+ var.closure = modified_closure
+ return var


  def _decay_sigma_(self: Module, params):
@@ -58,7 +58,7 @@ def _generate_perturbations_to_state_(self: Module, params: TensorList, n_sample
  for param, prt in zip(params, zip(*perturbations)):
  self.state[param]['perturbations'] = prt

- def _clear_state_hook(optimizer: Modular, vars: Vars, self: Module):
+ def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
  for m in optimizer.unrolled_modules:
  if m is not self:
  m.reset()
@@ -85,12 +85,12 @@ class GaussianHomotopy(Reformulation):
  else: self.global_state['generator'] = None
  return self.global_state['generator']

- def pre_step(self, vars):
- params = TensorList(vars.params)
+ def pre_step(self, var):
+ params = TensorList(var.params)
  settings = self.settings[params[0]]
  n_samples = settings['n_samples']
- init_sigma = self.get_settings('init_sigma', params=params)
- sigmas = self.get_state('sigma', params = params, init=init_sigma)
+ init_sigma = [self.settings[p]['init_sigma'] for p in params]
+ sigmas = self.get_state(params, 'sigma', init=init_sigma)

  if any('perturbations' not in self.state[p] for p in params):
  generator = self._get_generator(settings['seed'], params)
@@ -109,9 +109,9 @@ class GaussianHomotopy(Reformulation):
  tol = settings['tol']
  if tol is not None and not decayed:
  if not any('prev_params' in self.state[p] for p in params):
- prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
+ prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
  else:
- prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
+ prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
  s = params - prev_params

  if s.abs().global_max() <= tol:
@@ -124,10 +124,10 @@ class GaussianHomotopy(Reformulation):
  generator = self._get_generator(settings['seed'], params)
  _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
  if settings['clear_state']:
- vars.post_step_hooks.append(partial(_clear_state_hook, self=self))
+ var.post_step_hooks.append(partial(_clear_state_hook, self=self))

  @torch.no_grad
- def closure(self, backward, closure, params, vars):
+ def closure(self, backward, closure, params, var):
  params = TensorList(params)

  settings = self.settings[params[0]]
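For orientation (the estimator details are not shown in these hunks, so this is a hedged sketch): GaussianHomotopy optimizes a Gaussian-smoothed surrogate of the loss, estimated with the n_samples perturbations that pre_step stores in state, with sigma decayed over time:

f_\sigma(x) = \mathbb{E}_{\varepsilon \sim \mathcal{N}(0,\,\sigma^2 I)}\big[f(x+\varepsilon)\big] \approx \frac{1}{n}\sum_{i=1}^{n} f(x+\varepsilon_i), \qquad \varepsilon_i \sim \mathcal{N}(0,\,\sigma^2 I)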
torchzero/modules/smoothing/laplacian.py
@@ -67,7 +67,7 @@ class LaplacianSmoothing(Transform):
  minimum number of elements in a parameter to apply laplacian smoothing to.
  Only has effect if `layerwise` is True. Defaults to 4.
  target (str, optional):
- what to set on vars.
+ what to set on var.

  Reference:
  *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
@@ -82,19 +82,17 @@ class LaplacianSmoothing(Transform):


  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- layerwise = self.settings[params[0]]['layerwise']
+ def apply(self, tensors, params, grads, loss, states, settings):
+ layerwise = settings[0]['layerwise']

  # layerwise laplacian smoothing
  if layerwise:

  # precompute the denominator for each layer and store it in each parameters state
  smoothed_target = TensorList()
- for p, t in zip(params, tensors):
- settings = self.settings[p]
- if p.numel() > settings['min_numel']:
- state = self.state[p]
- if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, settings['sigma'])
+ for p, t, state, setting in zip(params, tensors, states, settings):
+ if p.numel() > setting['min_numel']:
+ if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, setting['sigma'])
  smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
  else:
  smoothed_target.append(t)
@@ -106,7 +104,7 @@ class LaplacianSmoothing(Transform):
  # precompute full denominator
  tensors = TensorList(tensors)
  if self.global_state.get('full_denominator', None) is None:
- self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), self.settings[params[0]]['sigma'])
+ self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), settings[0]['sigma'])

  # apply the smoothing
  vec = tensors.to_vec()
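The FFT division in apply follows the Laplacian smoothing of the cited Osher et al. reference; schematically, and assuming _precompute_denominator builds the Fourier transform of the (I - sigma*Delta) operator with the periodic stencil v = (-2, 1, 0, ..., 0, 1), the smoothed update is:

u_\sigma = (I - \sigma\Delta)^{-1} u = \mathcal{F}^{-1}\!\left(\frac{\mathcal{F}(u)}{1 - \sigma\,\mathcal{F}(v)}\right)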
torchzero/modules/weight_decay/__init__.py
@@ -1 +1 @@
- from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_
+ from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, NormalizedWeightDecay
torchzero/modules/weight_decay/weight_decay.py
@@ -1,9 +1,11 @@
  from collections.abc import Iterable, Sequence
+ from typing import Literal

  import torch

  from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList, as_tensorlist
+ from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+

  @torch.no_grad
  def weight_decay_(
@@ -25,12 +27,44 @@ class WeightDecay(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- weight_decay = self.get_settings('weight_decay', params=params, cls=NumberList)
- ord = self.settings[params[0]]['ord']
+ def apply(self, tensors, params, grads, loss, states, settings):
+ weight_decay = NumberList(s['weight_decay'] for s in settings)
+ ord = settings[0]['ord']

  return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

+ class NormalizedWeightDecay(Transform):
+ def __init__(
+ self,
+ weight_decay: float = 0.1,
+ ord: int = 2,
+ norm_input: Literal["update", "grad", "params"] = "update",
+ target: Target = "update",
+ ):
+ defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+ super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
+
+ @torch.no_grad
+ def apply(self, tensors, params, grads, loss, states, settings):
+ weight_decay = NumberList(s['weight_decay'] for s in settings)
+
+ ord = settings[0]['ord']
+ norm_input = settings[0]['norm_input']
+
+ if norm_input == 'update': src = TensorList(tensors)
+ elif norm_input == 'grad':
+ assert grads is not None
+ src = TensorList(grads)
+ elif norm_input == 'params':
+ src = TensorList(params)
+ else:
+ raise ValueError(norm_input)
+
+ norm = src.global_vector_norm(ord)
+
+ return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)
+
+
  @torch.no_grad
  def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
  """directly decays weights in-place"""
@@ -44,9 +78,9 @@ class DirectWeightDecay(Module):
  super().__init__(defaults)

  @torch.no_grad
- def step(self, vars):
- weight_decay = self.get_settings('weight_decay', params=vars.params, cls=NumberList)
- ord = self.settings[vars.params[0]]['ord']
+ def step(self, var):
+ weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
+ ord = self.settings[var.params[0]]['ord']

- decay_weights_(vars.params, weight_decay, ord)
- return vars
+ decay_weights_(var.params, weight_decay, ord)
+ return var
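In effect, the new NormalizedWeightDecay rescales the decay coefficient by the global norm of the chosen source (update, gradient, or parameters) before delegating to the same weight_decay_ helper. For the defaults (ord=2, norm_input='update'), and assuming weight_decay_ adds the decay term to the update, the applied term is:

\lambda_{\text{eff}} = \lambda\,\lVert u \rVert_2, \qquad u \leftarrow u + \lambda_{\text{eff}}\,\theta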
torchzero/modules/wrappers/optim_wrapper.py
@@ -24,8 +24,8 @@ class Wrap(Module):
  return super().set_param_groups(param_groups)

  @torch.no_grad
- def step(self, vars):
- params = vars.params
+ def step(self, var):
+ params = var.params

  # initialize opt on 1st step
  if self.optimizer is None:
@@ -35,18 +35,18 @@ class Wrap(Module):

  # set grad to update
  orig_grad = [p.grad for p in params]
- for p, u in zip(params, vars.get_update()):
+ for p, u in zip(params, var.get_update()):
  p.grad = u

  # if this module is last, can step with _opt directly
  # direct step can't be applied if next module is LR but _opt doesn't support lr,
  # and if there are multiple different per-parameter lrs (would be annoying to support)
- if vars.is_last and (
- (vars.last_module_lrs is None)
+ if var.is_last and (
+ (var.last_module_lrs is None)
  or
- (('lr' in self.optimizer.defaults) and (len(set(vars.last_module_lrs)) == 1))
+ (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
  ):
- lr = 1 if vars.last_module_lrs is None else vars.last_module_lrs[0]
+ lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]

  # update optimizer lr with desired lr
  if lr != 1:
@@ -68,19 +68,19 @@ class Wrap(Module):
  for p, g in zip(params, orig_grad):
  p.grad = g

- vars.stop = True; vars.skip_update = True
- return vars
+ var.stop = True; var.skip_update = True
+ return var

  # this is not the last module, meaning update is difference in parameters
  params_before_step = [p.clone() for p in params]
  self.optimizer.step() # step and update params
  for p, g in zip(params, orig_grad):
  p.grad = g
- vars.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
+ var.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
  for p, o in zip(params, params_before_step):
  p.set_(o) # pyright: ignore[reportArgumentType]

- return vars
+ return var

  def reset(self):
  super().reset()
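The non-last-module path in Wrap.step above reduces to: write the incoming update into p.grad, step the wrapped optimizer, read the parameter change back out as the new update, and restore the parameters. A standalone sketch of that pattern with a plain torch.optim.SGD (illustrative only, not the package's Wrap class):

import torch

params = [torch.zeros(3, requires_grad=True)]
inner = torch.optim.SGD(params, lr=0.1, momentum=0.9)

def step_as_update(update):
    # treat the incoming update as the gradient, let the wrapped optimizer step,
    # then recover the delta it applied and restore the original parameters
    before = [p.detach().clone() for p in params]
    for p, u in zip(params, update):
        p.grad = u
    inner.step()
    delta = [b - p.detach() for b, p in zip(before, params)]
    with torch.no_grad():
        for p, b in zip(params, before):
            p.copy_(b)
    return delta

print(step_as_update([torch.ones(3)]))  # SGD with lr=0.1 -> delta of 0.1 per coordinate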
torchzero/optim/wrappers/directsearch.py
@@ -0,0 +1,244 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import directsearch
+ import numpy as np
+ import torch
+ from directsearch.ds import DEFAULT_PARAMS
+
+ from ...modules.second_order.newton import tikhonov_
+ from ...utils import Optimizer, TensorList
+
+
+ def _ensure_float(x):
+ if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+ if isinstance(x, np.ndarray): return x.item()
+ return float(x)
+
+ def _ensure_numpy(x):
+ if isinstance(x, torch.Tensor): return x.detach().cpu()
+ if isinstance(x, np.ndarray): return x
+ return np.array(x)
+
+
+ Closure = Callable[[bool], Any]
+
+
+ class DirectSearch(Optimizer):
+ """Use directsearch as pytorch optimizer.
+
+ Note that this performs full minimization on each step,
+ so usually you would want to perform a single step, although performing multiple steps will refine the
+ solution.
+
+ Args:
+ params (_type_): _description_
+ maxevals (_type_, optional): _description_. Defaults to DEFAULT_PARAMS['maxevals'].
+ """
+ def __init__(
+ self,
+ params,
+ maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+ rho = DEFAULT_PARAMS['rho'], # Forcing function
+ sketch_dim = DEFAULT_PARAMS['sketch_dim'], # Target dimension for sketching
+ sketch_type = DEFAULT_PARAMS['sketch_type'], # Sketching technique
+ poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+ alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+ alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+ alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+ gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+ gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+ verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+ print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+ use_stochastic_three_points = DEFAULT_PARAMS['use_stochastic_three_points'], # Boolean for a specific method
+ rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+ ):
+ super().__init__(params, {})
+
+ kwargs = locals().copy()
+ del kwargs['self'], kwargs['params'], kwargs['__class__']
+ self._kwargs = kwargs
+
+ def _objective(self, x: np.ndarray, params: TensorList, closure):
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return _ensure_float(closure(False))
+
+ @torch.no_grad
+ def step(self, closure: Closure):
+ params = self.get_params()
+
+ x0 = params.to_vec().detach().cpu().numpy()
+
+ res = directsearch.solve(
+ partial(self._objective, params = params, closure = closure),
+ x0 = x0,
+ **self._kwargs
+ )
+
+ params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return res.f
+
+
+
+ class DirectSearchDS(Optimizer):
+ def __init__(
+ self,
+ params,
+ maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+ rho = DEFAULT_PARAMS['rho'], # Forcing function
+ poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+ alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+ alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+ alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+ gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+ gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+ verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+ print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+ rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+ ):
+ super().__init__(params, {})
+
+ kwargs = locals().copy()
+ del kwargs['self'], kwargs['params'], kwargs['__class__']
+ self._kwargs = kwargs
+
+ def _objective(self, x: np.ndarray, params: TensorList, closure):
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return _ensure_float(closure(False))
+
+ @torch.no_grad
+ def step(self, closure: Closure):
+ params = self.get_params()
+
+ x0 = params.to_vec().detach().cpu().numpy()
+
+ res = directsearch.solve_directsearch(
+ partial(self._objective, params = params, closure = closure),
+ x0 = x0,
+ **self._kwargs
+ )
+
+ params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return res.f
+
+ class DirectSearchProbabilistic(Optimizer):
+ def __init__(
+ self,
+ params,
+ maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+ rho = DEFAULT_PARAMS['rho'], # Forcing function
+ alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+ alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+ alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+ gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+ gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+ verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+ print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+ rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+ ):
+ super().__init__(params, {})
+
+ kwargs = locals().copy()
+ del kwargs['self'], kwargs['params'], kwargs['__class__']
+ self._kwargs = kwargs
+
+ def _objective(self, x: np.ndarray, params: TensorList, closure):
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return _ensure_float(closure(False))
+
+ @torch.no_grad
+ def step(self, closure: Closure):
+ params = self.get_params()
+
+ x0 = params.to_vec().detach().cpu().numpy()
+
+ res = directsearch.solve_probabilistic_directsearch(
+ partial(self._objective, params = params, closure = closure),
+ x0 = x0,
+ **self._kwargs
+ )
+
+ params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return res.f
+
+
+ class DirectSearchSubspace(Optimizer):
+ def __init__(
+ self,
+ params,
+ maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+ rho = DEFAULT_PARAMS['rho'], # Forcing function
+ sketch_dim = DEFAULT_PARAMS['sketch_dim'], # Target dimension for sketching
+ sketch_type = DEFAULT_PARAMS['sketch_type'], # Sketching technique
+ poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+ alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+ alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+ alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+ gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+ gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+ verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+ print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+ rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+ ):
+ super().__init__(params, {})
+
+ kwargs = locals().copy()
+ del kwargs['self'], kwargs['params'], kwargs['__class__']
+ self._kwargs = kwargs
+
+ def _objective(self, x: np.ndarray, params: TensorList, closure):
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return _ensure_float(closure(False))
+
+ @torch.no_grad
+ def step(self, closure: Closure):
+ params = self.get_params()
+
+ x0 = params.to_vec().detach().cpu().numpy()
+
+ res = directsearch.solve_subspace_directsearch(
+ partial(self._objective, params = params, closure = closure),
+ x0 = x0,
+ **self._kwargs
+ )
+
+ params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return res.f
+
+
+
+ class DirectSearchSTP(Optimizer):
+ def __init__(
+ self,
+ params,
+ maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+ alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+ alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+ verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+ print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+ ):
+ super().__init__(params, {})
+
+ kwargs = locals().copy()
+ del kwargs['self'], kwargs['params'], kwargs['__class__']
+ self._kwargs = kwargs
+
+ def _objective(self, x: np.ndarray, params: TensorList, closure):
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return _ensure_float(closure(False))
+
+ @torch.no_grad
+ def step(self, closure: Closure):
+ params = self.get_params()
+
+ x0 = params.to_vec().detach().cpu().numpy()
+
+ res = directsearch.solve_stp(
+ partial(self._objective, params = params, closure = closure),
+ x0 = x0,
+ **self._kwargs
+ )
+
+ params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return res.f
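As a usage note for the new wrappers (a sketch; the import path is taken from this diff and directsearch must be installed): the closure receives a backward flag but is only ever called with False since these methods are derivative-free, and step runs a full solve over the flattened parameters and returns the best objective found.

import torch
from torchzero.optim.wrappers.directsearch import DirectSearch  # path assumed from this diff

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

opt = DirectSearch(model.parameters(), maxevals=2000)

def closure(backward=True):
    # derivative-free: the wrapper always calls closure(False), so no .backward() is needed
    return torch.nn.functional.mse_loss(model(X), y)

best_loss = opt.step(closure)  # full directsearch solve; parameters are updated in place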