torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry; it is provided for informational purposes only.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/smoothing/gaussian.py
@@ -0,0 +1,164 @@
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Sequence
+ from functools import partial
+ from typing import Literal
+
+ import torch
+
+ from ...core import Modular, Module, Vars
+ from ...utils import NumberList, TensorList
+ from ...utils.derivatives import jacobian_wrt
+ from ..grad_approximation import GradApproximator, GradTarget
+
+
+ class Reformulation(Module, ABC):
+     def __init__(self, defaults):
+         super().__init__(defaults)
+
+     @abstractmethod
+     def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], vars: Vars) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+         """returns loss and gradient, if backward is False then gradient can be None"""
+
+     def pre_step(self, vars: Vars) -> Vars | None:
+         """This runs once before each step, whereas `closure` may run multiple times per step if further modules
+         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
+         return vars
+
+     def step(self, vars):
+         ret = self.pre_step(vars)
+         if isinstance(ret, Vars): vars = ret
+
+         if vars.closure is None: raise RuntimeError("Reformulation requires closure")
+         params, closure = vars.params, vars.closure
+
+
+         def modified_closure(backward=True):
+             loss, grad = self.closure(backward, closure, params, vars)
+
+             if grad is not None:
+                 for p,g in zip(params, grad):
+                     p.grad = g
+
+             return loss
+
+         vars.closure = modified_closure
+         return vars
+
+
+ def _decay_sigma_(self: Module, params):
+     for p in params:
+         state = self.state[p]
+         settings = self.settings[p]
+         state['sigma'] *= settings['decay']
+
+ def _generate_perturbations_to_state_(self: Module, params: TensorList, n_samples, sigmas, generator):
+     perturbations = [params.sample_like(generator=generator) for _ in range(n_samples)]
+     torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in sigmas for v in [vv]*n_samples])
+     for param, prt in zip(params, zip(*perturbations)):
+         self.state[param]['perturbations'] = prt
+
+ def _clear_state_hook(optimizer: Modular, vars: Vars, self: Module):
+     for m in optimizer.unrolled_modules:
+         if m is not self:
+             m.reset()
+
+ class GaussianHomotopy(Reformulation):
+     def __init__(
+         self,
+         n_samples: int,
+         init_sigma: float,
+         tol: float | None = 1e-4,
+         decay=0.5,
+         max_steps: int | None = None,
+         clear_state=True,
+         seed: int | None = None,
+     ):
+         defaults = dict(n_samples=n_samples, init_sigma=init_sigma, tol=tol, decay=decay, max_steps=max_steps, clear_state=clear_state, seed=seed)
+         super().__init__(defaults)
+
+
+     def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
+         if 'generator' not in self.global_state:
+             if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
+             elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             else: self.global_state['generator'] = None
+         return self.global_state['generator']
+
+     def pre_step(self, vars):
+         params = TensorList(vars.params)
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+         init_sigma = self.get_settings('init_sigma', params=params)
+         sigmas = self.get_state('sigma', params = params, init=init_sigma)
+
+         if any('perturbations' not in self.state[p] for p in params):
+             generator = self._get_generator(settings['seed'], params)
+             _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
+
+         # sigma decay rules
+         max_steps = settings['max_steps']
+         decayed = False
+         if max_steps is not None and max_steps > 0:
+             level_steps = self.global_state['level_steps'] = self.global_state.get('level_steps', 0) + 1
+             if level_steps > max_steps:
+                 self.global_state['level_steps'] = 0
+                 _decay_sigma_(self, params)
+                 decayed = True
+
+         tol = settings['tol']
+         if tol is not None and not decayed:
+             if not any('prev_params' in self.state[p] for p in params):
+                 prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
+             else:
+                 prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
+                 s = params - prev_params
+
+                 if s.abs().global_max() <= tol:
+                     _decay_sigma_(self, params)
+                     decayed = True
+
+                 prev_params.copy_(params)
+
+         if decayed:
+             generator = self._get_generator(settings['seed'], params)
+             _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
+             if settings['clear_state']:
+                 vars.post_step_hooks.append(partial(_clear_state_hook, self=self))
+
+     @torch.no_grad
+     def closure(self, backward, closure, params, vars):
+         params = TensorList(params)
+
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+
+         perturbations = list(zip(*(self.state[p]['perturbations'] for p in params)))
+
+         loss = None
+         grad = None
+         for i in range(n_samples):
+             prt = perturbations[i]
+
+             params.add_(prt)
+             if backward:
+                 with torch.enable_grad(): l = closure()
+                 if grad is None: grad = params.grad
+                 else: grad += params.grad
+
+             else:
+                 l = closure(False)
+
+             if loss is None: loss = l
+             else: loss = loss+l
+
+             params.sub_(prt)
+
+         assert loss is not None
+         if n_samples > 1:
+             loss = loss / n_samples
+             if backward:
+                 assert grad is not None
+                 grad.div_(n_samples)
+
+         return loss, grad
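The closure built by `GaussianHomotopy` averages the loss and gradient over `n_samples` fixed random perturbations of the parameters, which approximates optimizing a Gaussian-blurred version of the objective; `sigma` is decayed once progress at the current smoothing level stalls. A standalone toy sketch of that averaging step (plain PyTorch, not the torchzero API; names such as `smoothed_closure` are illustrative only):

    import torch

    def smoothed_closure(params, closure, perturbations):
        # average loss and gradient over fixed parameter perturbations,
        # mirroring the loop in GaussianHomotopy.closure above
        total_loss = 0.0
        avg_grads = [torch.zeros_like(p) for p in params]
        for prt in perturbations:
            with torch.no_grad():
                for p, e in zip(params, prt): p.add_(e)   # shift params by the perturbation
            total_loss += float(closure())                # closure computes loss and populates .grad
            for g, p in zip(avg_grads, params): g += p.grad
            with torch.no_grad():
                for p, e in zip(params, prt): p.sub_(e)   # undo the shift
        n = len(perturbations)
        return total_loss / n, [g / n for g in avg_grads]

    # toy objective: f(x) = sum(x^2)
    x = torch.nn.Parameter(torch.tensor([3.0, -2.0]))

    def closure():
        x.grad = None
        loss = (x ** 2).sum()
        loss.backward()
        return loss

    sigma = 1.0
    perturbations = [[sigma * torch.randn_like(x)] for _ in range(4)]  # one fixed sample set per level
    loss, (gx,) = smoothed_closure([x], closure, perturbations)
    with torch.no_grad():
        x -= 0.1 * gx  # one descent step on the smoothed objective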
torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py}
@@ -1,128 +1,115 @@
- from typing import Literal
- from collections.abc import Iterable
-
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizerModule
-
-
- def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
-     """Returns a new vector with laplacian smoothing applied to it. This flattens the input!"""
-     vec = input.view(-1)
-     v = torch.zeros_like(vec)
-     v[0] = -2
-     v[1] = 1
-     v[-1] = 1
-     numerator = torch.fft.fft(vec) # pylint: disable = not-callable
-     denominator = 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
-     return torch.fft.ifft(numerator / denominator).real # pylint: disable = not-callable
-
- def gradient_laplacian_smoothing_(params: Iterable[torch.Tensor], sigma: float = 1, layerwise=True, min_numel = 4):
-     """Applies laplacian smoothing to gradients of an iterable of parameters.
-
-     This updates gradients in-place.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): an iterable of Tensors that will have gradients smoothed.
-         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
-         layerwise (bool, optional):
-             If True, applies smoothing to each parameter's gradient separately,
-             Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
-         min_numel (int, optional):
-             minimum number of elements in a parameter to apply laplacian smoothing to.
-             Only has effect if `layerwise` is True. Defaults to 4.
-
-     Reference:
-         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
-         Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
-     """
-     grads = TensorList(params).get_existing_grads()
-     if layerwise:
-         for g in grads:
-             if g.numel() >= min_numel:
-                 g.set_(vector_laplacian_smoothing(g, sigma).reshape(g.shape)) # type:ignore
-     else:
-         vec = grads.to_vec()
-         grads.from_vec_(vector_laplacian_smoothing(vec, sigma))
-
-
- def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
-     """Denominator will always be the same and depends on the size of the vector and the sigma."""
-     v = torch.zeros_like(tensor.view(-1))
-     v[0] = -2
-     v[1] = 1
-     v[-1] = 1
-     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
-
- class LaplacianSmoothing(OptimizerModule):
-     """Applies laplacian smoothing via a fast Fourier transform solver.
-
-     Args:
-         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
-         layerwise (bool, optional):
-             If True, applies smoothing to each parameter's gradient separately,
-             Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
-         min_numel (int, optional):
-             minimum number of elements in a parameter to apply laplacian smoothing to.
-             Only has effect if `layerwise` is True. Defaults to 4.
-         target (str, optional):
-             determines what this module updates.
-
-             "ascent" - it updates the ascent (default).
-
-             "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-             "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-
-     Reference:
-         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
-         Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
-
-     """
-     def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Literal['ascent', 'grad', 'closure',] = 'ascent'):
-         # sigma from defaults is used in layerwise case
-         # otherwise self.sigma is used
-         defaults = dict(sigma = sigma)
-         self.sigma = 1
-         super().__init__(defaults, target=target)
-         self.layerwise = layerwise
-         self.min_numel = min_numel
-
-         # precomputed denominator for when layerwise=False
-         self.full_denominator = None
-
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-         sigmas = self.get_group_key('sigma')
-
-         # layerwise laplacian smoothing
-         if self.layerwise:
-
-             # precompute the denominator for each layer and store it in each parameters state
-             denominators = TensorList()
-             for p, σ in zip(params, sigmas):
-                 if p.numel() > self.min_numel:
-                     den = self.state[p]
-                     if 'denominator' not in den: den['denominator'] = _precompute_denominator(p, σ)
-                     denominators.append(den['denominator'])
-
-             # apply the smoothing
-             smoothed_direction = TensorList()
-             for g, σ, den in zip(ascent, sigmas, denominators):
-                 smoothed_direction.append(torch.fft.ifft(torch.fft.fft(g.view(-1)) / den).real.reshape(g.shape)) # pylint: disable = not-callable
-             return smoothed_direction
-
-         # else
-         # full laplacian smoothing
-         # precompute full denominator
-         if self.full_denominator is None:
-             self.full_denominator = _precompute_denominator(ascent.to_vec(), self.sigma)
-
-         # apply the smoothing
-         vec = ascent.to_vec()
-         return ascent.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.full_denominator).real) # pylint: disable = not-callable
-
-
+ from typing import Literal
+ from collections.abc import Iterable
+
+ import torch
+
+ from ...utils.tensorlist import TensorList
+ from ...core import Transform, Target
+
+
+ def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
+     """Returns a new vector with laplacian smoothing applied to it. This flattens the input!"""
+     vec = input.view(-1)
+     v = torch.zeros_like(vec)
+     v[0] = -2
+     v[1] = 1
+     v[-1] = 1
+     numerator = torch.fft.fft(vec) # pylint: disable = not-callable
+     denominator = 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
+     return torch.fft.ifft(numerator / denominator).real # pylint: disable = not-callable
+
+ def gradient_laplacian_smoothing_(params: Iterable[torch.Tensor], sigma: float = 1, layerwise=True, min_numel = 4):
+     """Applies laplacian smoothing to gradients of an iterable of parameters.
+
+     This updates gradients in-place.
+
+     Args:
+         params (abc.Iterable[torch.Tensor]): an iterable of Tensors that will have gradients smoothed.
+         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
+         layerwise (bool, optional):
+             If True, applies smoothing to each parameter's gradient separately,
+             Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
+         min_numel (int, optional):
+             minimum number of elements in a parameter to apply laplacian smoothing to.
+             Only has effect if `layerwise` is True. Defaults to 4.
+
+     Reference:
+         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
+         Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+     """
+     grads = TensorList(params).get_grad()
+     if layerwise:
+         for g in grads:
+             if g.numel() >= min_numel:
+                 g.set_(vector_laplacian_smoothing(g, sigma).view_as(g)) # pyright:ignore[reportArgumentType]
+     else:
+         vec = grads.to_vec()
+         grads.from_vec_(vector_laplacian_smoothing(vec, sigma))
+
+
+ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
+     """Denominator will always be the same and depends on the size of the vector and the sigma."""
+     v = torch.zeros_like(tensor.view(-1))
+     v[0] = -2
+     v[1] = 1
+     v[-1] = 1
+     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
+
+ class LaplacianSmoothing(Transform):
+     """Applies laplacian smoothing via a fast Fourier transform solver.
+
+     Args:
+         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
+         layerwise (bool, optional):
+             If True, applies smoothing to each parameter's gradient separately,
+             Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
+         min_numel (int, optional):
+             minimum number of elements in a parameter to apply laplacian smoothing to.
+             Only has effect if `layerwise` is True. Defaults to 4.
+         target (str, optional):
+             what to set on vars.
+
+     Reference:
+         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
+         Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+
+     """
+     def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
+         defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
+         super().__init__(defaults, uses_grad=False, target=target)
+         # precomputed denominator for when layerwise=False
+         self.global_state['full_denominator'] = None
+
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         layerwise = self.settings[params[0]]['layerwise']
+
+         # layerwise laplacian smoothing
+         if layerwise:
+
+             # precompute the denominator for each layer and store it in each parameters state
+             smoothed_target = TensorList()
+             for p, t in zip(params, tensors):
+                 settings = self.settings[p]
+                 if p.numel() > settings['min_numel']:
+                     state = self.state[p]
+                     if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, settings['sigma'])
+                     smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
+                 else:
+                     smoothed_target.append(t)
+
+             return smoothed_target
+
+         # else
+         # full laplacian smoothing
+         # precompute full denominator
+         tensors = TensorList(tensors)
+         if self.global_state.get('full_denominator', None) is None:
+             self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), self.settings[params[0]]['sigma'])
+
+         # apply the smoothing
+         vec = tensors.to_vec()
+         return tensors.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.global_state['full_denominator']).real)#pylint:disable=not-callable
+
+
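The in-place helper `gradient_laplacian_smoothing_` can also be used on its own, outside the module system: compute gradients as usual, smooth them, then let any optimizer consume them. A minimal sketch, assuming the import path that follows from the new file location:

    import torch
    from torchzero.modules.smoothing.laplacian import gradient_laplacian_smoothing_  # path per this diff

    model = torch.nn.Linear(64, 10)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()

    # smooth the raw gradients in place before the optimizer consumes them
    gradient_laplacian_smoothing_(model.parameters(), sigma=1.0, layerwise=True)
    opt.step()
    opt.zero_grad()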
torchzero/modules/weight_decay/__init__.py
@@ -0,0 +1 @@
+ from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_
torchzero/modules/weight_decay/weight_decay.py
@@ -0,0 +1,52 @@
+ from collections.abc import Iterable, Sequence
+
+ import torch
+
+ from ...core import Module, Target, Transform
+ from ...utils import NumberList, TensorList, as_tensorlist
+
+ @torch.no_grad
+ def weight_decay_(
+     grad_: TensorList,
+     params: TensorList,
+     weight_decay: float | NumberList,
+     ord: int = 2
+ ):
+     """returns `grad_`."""
+     if ord == 1: return grad_.add_(params.sign().mul_(weight_decay))
+     if ord == 2: return grad_.add_(params.mul(weight_decay))
+     if ord - 1 % 2 != 0: return grad_.add_(params.pow(ord-1).mul_(weight_decay))
+     return grad_.add_(params.pow(ord-1).copysign_(params).mul_(weight_decay))
+
+
+ class WeightDecay(Transform):
+     def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+         defaults = dict(weight_decay=weight_decay, ord=ord)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         weight_decay = self.get_settings('weight_decay', params=params, cls=NumberList)
+         ord = self.settings[params[0]]['ord']
+
+         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
+
+ @torch.no_grad
+ def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
+     """directly decays weights in-place"""
+     params = TensorList(params)
+     weight_decay_(params, params, -weight_decay, ord)
+
+ class DirectWeightDecay(Module):
+     """directly decays weights in-place"""
+     def __init__(self, weight_decay: float, ord: int = 2,):
+         defaults = dict(weight_decay=weight_decay, ord=ord)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, vars):
+         weight_decay = self.get_settings('weight_decay', params=vars.params, cls=NumberList)
+         ord = self.settings[vars.params[0]]['ord']
+
+         decay_weights_(vars.params, weight_decay, ord)
+         return vars
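For `ord=2`, `decay_weights_` passes `-weight_decay` through `weight_decay_`, so each parameter becomes `p - weight_decay * p`, i.e. it is scaled by `(1 - weight_decay)` in place. A minimal sketch using the export added in `torchzero/modules/weight_decay/__init__.py` above:

    import torch
    from torchzero.modules.weight_decay import decay_weights_  # export per this diff

    p = torch.nn.Parameter(torch.ones(3) * 2.0)
    decay_weights_([p], weight_decay=0.1)  # ord=2: p <- p - 0.1 * p
    print(p)  # values become 1.8, i.e. 2.0 * (1 - 0.1)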
torchzero/modules/wrappers/__init__.py
@@ -0,0 +1 @@
+ from .optim_wrapper import Wrap
torchzero/modules/wrappers/optim_wrapper.py
@@ -0,0 +1,91 @@
+ from collections.abc import Iterable, Mapping, Sequence, Callable
+ from typing import Any
+ import torch
+
+ from ...core.module import Module
+ from ...utils import Params, _copy_param_groups, _make_param_groups
+
+
+ class Wrap(Module):
+     """Custom param groups are supported only by `set_param_groups`. Settings passed to Modular will be ignored."""
+     def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
+         super().__init__()
+         self._opt_fn = opt_fn
+         self._opt_args = args
+         self._opt_kwargs = kwargs
+         self._custom_param_groups = None
+
+         self.optimizer: torch.optim.Optimizer | None = None
+         if isinstance(self._opt_fn, torch.optim.Optimizer) or not callable(self._opt_fn):
+             self.optimizer = self._opt_fn
+
+     def set_param_groups(self, param_groups):
+         self._custom_param_groups = param_groups
+         return super().set_param_groups(param_groups)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = vars.params
+
+         # initialize opt on 1st step
+         if self.optimizer is None:
+             assert callable(self._opt_fn)
+             param_groups = params if self._custom_param_groups is None else self._custom_param_groups
+             self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
+
+         # set grad to update
+         orig_grad = [p.grad for p in params]
+         for p, u in zip(params, vars.get_update()):
+             p.grad = u
+
+         # if this module is last, can step with _opt directly
+         # direct step can't be applied if next module is LR but _opt doesn't support lr,
+         # and if there are multiple different per-parameter lrs (would be annoying to support)
+         if vars.is_last and (
+             (vars.last_module_lrs is None)
+             or
+             (('lr' in self.optimizer.defaults) and (len(set(vars.last_module_lrs)) == 1))
+         ):
+             lr = 1 if vars.last_module_lrs is None else vars.last_module_lrs[0]
+
+             # update optimizer lr with desired lr
+             if lr != 1:
+                 self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
+                 for g in self.optimizer.param_groups:
+                     g['__original_lr__'] = g['lr']
+                     g['lr'] = g['lr'] * lr
+
+             # step
+             self.optimizer.step()
+
+             # restore original lr
+             if lr != 1:
+                 self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
+                 for g in self.optimizer.param_groups:
+                     g['lr'] = g.pop('__original_lr__')
+
+             # restore grad
+             for p, g in zip(params, orig_grad):
+                 p.grad = g
+
+             vars.stop = True; vars.skip_update = True
+             return vars
+
+         # this is not the last module, meaning update is difference in parameters
+         params_before_step = [p.clone() for p in params]
+         self.optimizer.step() # step and update params
+         for p, g in zip(params, orig_grad):
+             p.grad = g
+         vars.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
+         for p, o in zip(params, params_before_step):
+             p.set_(o) # pyright: ignore[reportArgumentType]
+
+         return vars
+
+     def reset(self):
+         super().reset()
+         assert self.optimizer is not None
+         for g in self.optimizer.param_groups:
+             for p in g['params']:
+                 state = self.optimizer.state[p]
+                 state.clear()
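When `Wrap` is the last module, it folds the outer learning rate into the wrapped optimizer by temporarily rescaling each param group's `lr` and restoring it afterwards. The same save/scale/restore pattern in isolation, with a plain torch optimizer (illustrative only, not the torchzero API):

    import torch

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    model(torch.randn(2, 4)).sum().backward()

    outer_lr = 0.5  # lr requested by an outer module, as in vars.last_module_lrs
    for g in opt.param_groups:
        g['__original_lr__'] = g['lr']
        g['lr'] = g['lr'] * outer_lr  # step with the combined lr

    opt.step()

    for g in opt.param_groups:
        g['lr'] = g.pop('__original_lr__')  # restore so later steps are unaffected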
torchzero/optim/__init__.py
@@ -1,10 +1,2 @@
- r"""
- Ready to use optimizers.
- """
- from .modular import Modular
- from .quasi_newton import *
- from .zeroth_order import *
- from .second_order import *
- from .first_order import *
- # from .wrappers.scipy import ScipyMinimize
- from . import experimental
+ from .utility import *
+ from .wrappers import *
torchzero/optim/utility/__init__.py
@@ -0,0 +1 @@
+ from .split import Split
torchzero/optim/utility/split.py
@@ -0,0 +1,45 @@
+ import warnings
+ from collections.abc import Callable, Iterable
+
+ import torch
+
+ from ...utils import flatten, get_params
+
+ class Split(torch.optim.Optimizer):
+     """Steps will all `optimizers`, also has a check that they have no duplicate parameters.
+     Doesn't support closure based optimizers.
+
+     Example:
+
+     .. code:: py
+
+         opt = Split(
+             torch.optim.Adam(model.encoder.parameters(), lr=0.001),
+             torch.optim.SGD(model.decoder.parameters(), lr=0.1)
+         )
+     """
+     def __init__(self, *optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer]):
+         all_params = []
+         self.optimizers: list[torch.optim.Optimizer] = flatten(optimizers)
+
+         # gather all params in case user tries to access them from this object
+         for i,opt in enumerate(self.optimizers):
+             for p in get_params(opt.param_groups, 'all', list):
+                 if p not in all_params: all_params.append(p)
+                 else: warnings.warn(
+                     f'optimizers[{i}] {opt.__class__.__name__} has some duplicate parameters '
+                     'that are also in previous optimizers. They will be updated multiple times.')
+
+         super().__init__(all_params, {})
+
+     def step(self, closure: Callable | None = None):
+         loss = None
+
+         # if closure provided, populate grad, otherwise each optimizer will call closure separately
+         if closure is not None:
+             with torch.enable_grad(): loss = closure()
+
+         for opt in self.optimizers:
+             opt.step() # closure not passed as grad is already evaluated
+
+         return loss
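`Split.step` evaluates the closure once to populate gradients, then steps each wrapped optimizer without a closure (hence the note that closure-based optimizers are not supported). A small end-to-end sketch; `from torchzero.optim import Split` follows from the `__init__` re-exports shown in this diff:

    import torch
    from torchzero.optim import Split  # re-exported via torchzero/optim/utility/__init__.py

    encoder = torch.nn.Linear(8, 4)
    decoder = torch.nn.Linear(4, 8)

    opt = Split(
        torch.optim.Adam(encoder.parameters(), lr=1e-3),
        torch.optim.SGD(decoder.parameters(), lr=1e-1),
    )

    def closure():
        opt.zero_grad()
        x = torch.randn(16, 8)
        loss = (decoder(encoder(x)) - x).pow(2).mean()
        loss.backward()
        return loss

    loss = opt.step(closure)  # closure runs once, then Adam and SGD each step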
torchzero/optim/wrappers/nevergrad.py
@@ -6,7 +6,7 @@ import torch
 
   import nevergrad as ng
 
- from ...core import TensorListOptimizer
+ from ...utils import Optimizer
 
 
   def _ensure_float(x):
@@ -14,7 +14,7 @@ def _ensure_float(x):
      if isinstance(x, np.ndarray): return x.item()
      return float(x)
 
- class NevergradOptimizer(TensorListOptimizer):
+ class NevergradOptimizer(Optimizer):
      """Use nevergrad optimizer as pytorch optimizer.
      Note that it is recommended to specify `budget` to the number of iterations you expect to run,
      as some nevergrad optimizers will error without it.
@@ -85,29 +85,3 @@ class NevergradOptimizer(TensorListOptimizer):
          loss = closure(False)
          self.opt.tell(x, _ensure_float(loss))
          return loss
-
-
-
- # class NevergradSubspace(ModularOptimizer):
- #     def __init__(
- #         self,
- #         params,
- #         opt_cls:"type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
- #         budget=None,
- #         mutable_sigma = False,
- #         use_init = True,
- #         projections = Proj2Masks(5),
- #     ):
-
- #         modules = [
- #             Subspace(projections, update_every=100),
- #             UninitializedClosureOptimizerWrapper(
- #                 NevergradOptimizer,
- #                 opt_cls = opt_cls,
- #                 budget = budget,
- #                 mutable_sigma = mutable_sigma,
- #                 use_init = use_init,
- #             ),
- #         ]
-
- #         super().__init__(params, modules)