torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/optim/modular.py
@@ -1,148 +0,0 @@
- from collections import abc
- import warnings
- from inspect import cleandoc
- import torch
- from typing import Any
-
- from ..core import OptimizerModule, TensorListOptimizer, OptimizationVars, _Chain, _Chainable
- from ..utils.python_tools import flatten
-
- def _unroll_modules(flat_modules: list[OptimizerModule], nested) -> list[OptimizerModule]:
-     """returns a list of all modules, including all nested ones"""
-     unrolled = []
-     for m in flat_modules:
-         unrolled.append(m)
-         if len(m.children) > 0:
-             unrolled.extend(_unroll_modules(list(m.children.values()), nested=True))
-         if nested:
-             if m.next_module is not None:
-                 unrolled.extend(_unroll_modules([m.next_module], nested=True))
-     return unrolled
-
-
- class Modular(TensorListOptimizer):
-     """Creates a modular optimizer by chaining together a sequence of optimizer modules.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         *modules (Iterable[OptimizerModule] | OptimizerModule):
-             A sequence of optimizer modules to chain together. This argument will be flattened."""
-     def __init__(self, params, *modules: _Chainable):
-         flat_modules = flatten(modules)
-         self.modules: list[OptimizerModule] = flat_modules
-         self.chain = _Chain(flat_modules)
-
-         # save unrolled modules and make sure there is only 1 LR module.
-         self.unrolled_modules = _unroll_modules(flat_modules, nested=False)
-         num_lr_modules = len([m for m in self.unrolled_modules if m.IS_LR_MODULE])
-         if num_lr_modules > 1:
-             warnings.warn(cleandoc(
-                 f"""More then 1 lr modules have been added.
-                 This may lead to incorrect behaviour with learning rate scheduling and per-parameter learning rates.
-                 Make sure there is a single `LR` module, use `Alpha` module instead of it where needed.
-                 \nList of modules: {self.unrolled_modules}; \nlist of lr modules: {[m for m in self.unrolled_modules if m.IS_LR_MODULE]}"""
-             ))
-
-         if isinstance(params, torch.nn.Module):
-             self.model = params
-             params = list(params.parameters())
-         else:
-             self.model = None
-             params = list(params)
-
-         # if there is an `lr` setting, make sure there is an LR module that can use it
-         for p in params:
-             if isinstance(p, dict):
-                 if 'lr' in p:
-                     if num_lr_modules == 0:
-                         warnings.warn(cleandoc(
-                             """Passed "lr" setting in a parameter group, but there is no LR module that can use that setting.
-                             Add an `LR` module to make per-layer "lr" setting work."""
-                         ))
-
-         super().__init__(params, {})
-         self.chain._initialize_(params, set_passed_params=True)
-
-         # run post-init hooks
-         for module in self.unrolled_modules:
-             for hook in module.post_init_hooks:
-                 hook(self, module)
-
-     def state_dict(self):
-         state_dict = {}
-         state_dict['__self__'] = super().state_dict()
-         for i,v in enumerate(self.unrolled_modules):
-             state_dict[str(i)] = v.state_dict()
-         return state_dict
-
-     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-         super().load_state_dict(state_dict['__self__'])
-         for i,v in enumerate(self.unrolled_modules):
-             if str(i) in state_dict:
-                 v.load_state_dict(state_dict[str(i)])
-             else:
-                 warnings.warn(f"Tried to load state dict for {i}th module: {v.__class__.__name__}, but it is not present in state_dict with {list(state_dict.keys()) = }")
-
-     def get_lr_module(self, last=True) -> OptimizerModule:
-         """
-         Retrieves the module in the chain that controls the learning rate.
-
-         This method is useful for setting up a learning rate scheduler. By default, it retrieves the last module in the chain
-         that has an `lr` group parameter.
-
-         Args:
-             last (bool, optional):
-                 If multiple modules have an `lr` parameter, this argument controls which one is returned.
-                 - If `True` (default), the last module is returned.
-                 - If `False`, the first module is returned.
-
-         Returns:
-             OptimizerModule: The module that controls the learning rate.
-
-         Raises:
-             ValueError: If no modules in the chain have an `lr` parameter. To fix this, add an `LR` module.
-
-         Example:
-
-         .. code:: py
-             from torch.optim.lr_scheduler import OneCycleLR
-             import torchzero as tz
-
-             opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2), tz.m.DirectionalNewton()])
-             lr_scheduler = OneCycleLR(opt.get_lr_module(), max_lr = 1e-1, total_steps = 1000, cycle_momentum=False)
-
-         """
-         modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
-         for m in modules:
-             if 'lr' in m.param_groups[0]: return m
-
-         raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} support and `lr` parameter. The easiest way to fix is is to add an `LR(1)` module at the end.')
-
-     def get_module_by_name(self, name: str | type, last=True) -> OptimizerModule:
-         """Returns the first or last module in the chain that matches the provided name or type.
-
-         Args:
-             name (str | type): the name (as a string) or the type of the module to search for.
-             last (bool, optional):
-                 If multiple modules match, this argument controls which one is returned.
-                 - If `True` (default), the last matching module is returned.
-                 - If `False`, the first matching module is returned.
-
-         Returns:
-             OptimizerModule: The matching optimizer module.
-
-         Raises:
-             ValueError: If no modules in the chain match the provided name or type.
-         """
-         modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
-         for m in modules:
-             if isinstance(name, str) and m.__class__.__name__ == name: return m
-             if isinstance(name, type) and isinstance(m, name): return m
-
-         raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} match "{name}".')
-
-     def step(self, closure=None): # type:ignore
-         vars = OptimizationVars(closure, self.model)
-         res = self.chain.step(vars)
-         for hook in vars.post_step_hooks: hook(self, vars)
-         return res
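The `Modular` class removed above was the 0.1.x entry point for chaining optimizer modules; its `get_lr_module` docstring shows how a learning-rate scheduler was meant to be attached. The sketch below simply expands that docstring example into a full loop; the model, data, and closure are hypothetical placeholders, and the closure follows the standard PyTorch `step(closure)` convention.

```
# Minimal sketch of the removed 0.1.x Modular API, expanded from the
# get_lr_module docstring above. Model, data and loss are hypothetical.
import torch
import torchzero as tz
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(10, 1)
opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2)])
# get_lr_module() returns the module whose param_groups carry `lr`,
# so a torch scheduler can drive it directly.
scheduler = OneCycleLR(opt.get_lr_module(), max_lr=1e-1, total_steps=1000, cycle_momentum=False)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(1000):
    def closure():
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss
    opt.step(closure)
    scheduler.step()
```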
torchzero/optim/quasi_newton/__init__.py
@@ -1 +0,0 @@
- from .directional_newton import DirectionalNewton
torchzero/optim/quasi_newton/directional_newton.py
@@ -1,58 +0,0 @@
- from ...modules import (
-     SGD,
- )
- from ...modules import DirectionalNewton as _DirectionalNewton, LR
- from ..modular import Modular
-
-
- class DirectionalNewton(Modular):
-     """Minimizes a parabola in the direction of the gradient (or update if momentum or weight decay is enabled)
-     via one additional forward pass, and uses another forward pass to make sure it didn't overstep.
-     So in total this performs three forward passes and one backward.
-
-     First forward and backward pass is used to calculate the value and gradient at initial parameters.
-     Then a gradient descent step is performed with `lr` learning rate, and loss is recalculated
-     with new parameters. A quadratic is fitted to two points and gradient,
-     if it has positive curvature, this makes a step towards the minimum, and checks if lr decreased
-     with an additional forward pass.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional):
-             learning rate. Since you shouldn't put this module after LR(), you have to specify
-             the learning rate in this argument. Defaults to 1e-2.
-         max_dist (float | None, optional):
-             maximum distance to step when minimizing quadratic.
-             If minimum is further than this distance, minimization is not performed. Defaults to 1e4.
-         validate_step (bool, optional):
-             uses an additional forward pass to check
-             if step towards the minimum actually decreased the loss. Defaults to True.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-
-     Note:
-         While lr scheduling is supported, this uses lr of the first parameter for all parameters.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-4,
-         max_dist: float | None = 1e5,
-         validate_step: bool = True,
-         momentum: float = 0,
-         dampening: float = 0,
-         weight_decay: float = 0,
-         nesterov: bool = False,
-
-     ):
-
-         modules = [
-             SGD(momentum=momentum,dampening=dampening,weight_decay=weight_decay,nesterov=nesterov),
-             LR(lr),
-             _DirectionalNewton(max_dist, validate_step)
-         ]
-         super().__init__(params, modules)
-
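The `DirectionalNewton` preset deleted here was a thin `Modular` chain of `SGD`, `LR` and the `_DirectionalNewton` module. A usage sketch, assuming the `torchzero.optim.quasi_newton` re-export from the deleted `__init__.py` above and a standard PyTorch-style closure (the module needs its extra forward passes, so a closure is required):

```
# Hypothetical usage of the removed 0.1.x DirectionalNewton preset.
import torch
from torchzero.optim.quasi_newton import DirectionalNewton

model = torch.nn.Linear(4, 1)
data = torch.randn(32, 4)
opt = DirectionalNewton(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)

def closure():
    # per the docstring, each step costs three forward passes and one backward
    opt.zero_grad()
    loss = model(data).pow(2).mean()
    loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```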
torchzero/optim/second_order/__init__.py
@@ -1 +0,0 @@
- from .newton import ExactNewton
torchzero/optim/second_order/newton.py
@@ -1,94 +0,0 @@
- from typing import Any, Literal
-
- import torch
-
- from ...modules import (
-     LR,
-     ClipNorm,
-     FallbackLinearSystemSolvers,
-     LinearSystemSolvers,
-     LineSearches,
-     get_line_search,
- )
- from ...modules import ExactNewton as _ExactNewton
- from ..modular import Modular
-
-
- class ExactNewton(Modular):
-     """Peforms an exact Newton step using batched autograd. Note that torch.func would be way more efficient
-     but much more restrictive to what operations are allowed (I will add it at some point).
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate. Defaults to 1.
-         tikhonov (float, optional):
-             tikhonov regularization (constant value added to the diagonal of the hessian). Defaults to 0.
-         solver (LinearSystemSolvers, optional):
-             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-         fallback (FallbackLinearSystemSolvers, optional):
-             what to do if solver fails. Defaults to "safe_diag"
-             (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-         max_norm (float, optional):
-             clips the newton step to L2 norm to avoid instability by giant steps.
-             A mauch better way is to use trust region methods. I haven't implemented any
-             but you can use `tz.optim.wrappers.scipy.ScipyMinimize` with one of the trust region methods.
-             Defaults to None.
-         validate (bool, optional):
-             validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-             If not, undo the step and perform a gradient descent step.
-         tol (float, optional):
-             only has effect if `validate` is enabled.
-             If loss increased by `loss * tol`, perform gradient descent step.
-             Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-         gd_lr (float, optional):
-             only has effect if `validate` is enabled.
-             Gradient descent step learning rate. Defaults to 1e-2.
-         line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to None.
-         batched_hessian (bool, optional):
-             whether to use experimental pytorch vmap-vectorized hessian calculation. As per pytorch docs,
-             should be faster, but this feature being experimental, there may be performance cliffs.
-             Defaults to True.
-         diag (False, optional):
-             only use the diagonal of the hessian. This will still calculate the full hessian!
-             This is mainly useful for benchmarking.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1,
-         tikhonov: float | Literal['eig'] = 0.0,
-         solver: LinearSystemSolvers = "cholesky_lu",
-         fallback: FallbackLinearSystemSolvers = "safe_diag",
-         max_norm: float | None = None,
-         validate=False,
-         tol: float = 1,
-         gd_lr = 1e-2,
-         line_search: LineSearches | None = None,
-         batched_hessian = True,
-
-         diag: bool = False,
-     ):
-         modules: list[Any] = [
-             _ExactNewton(
-                 tikhonov=tikhonov,
-                 batched_hessian=batched_hessian,
-                 solver=solver,
-                 fallback=fallback,
-                 validate=validate,
-                 tol = tol,
-                 gd_lr=gd_lr,
-                 diag = diag,
-             ),
-         ]
-
-         if max_norm is not None:
-             modules.append(ClipNorm(max_norm))
-
-         modules.append(LR(lr))
-
-         if line_search is not None:
-             modules.append(get_line_search(line_search))
-
-         super().__init__(params, modules)
-
-
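The deleted `ExactNewton` preset chained the `_ExactNewton` module with an optional `ClipNorm`, an `LR` step and an optional line search. A sketch of how it was put together, assuming the `torchzero.optim.second_order` re-export above and a standard closure; the small quadratic objective is a placeholder:

```
# Sketch only: the removed 0.1.x ExactNewton preset on a tiny SPD quadratic.
import torch
from torchzero.optim.second_order import ExactNewton

w = torch.nn.Parameter(torch.randn(5))
A = torch.randn(5, 5)
A = A @ A.T + 5 * torch.eye(5)   # symmetric positive definite Hessian

opt = ExactNewton(
    [w],
    lr=1.0,
    tikhonov=1e-4,   # constant added to the Hessian diagonal
    max_norm=10.0,   # appended as ClipNorm(10.0) by the preset
)

def closure():
    # standard closure; the Hessian is computed inside the module via batched autograd
    opt.zero_grad()
    loss = 0.5 * w @ A @ w
    loss.backward()
    return loss

for _ in range(5):
    opt.step(closure)
```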
torchzero/optim/zeroth_order/__init__.py
@@ -1,4 +0,0 @@
- from .fdm import FDM, FDMWrapper
- from .newton_fdm import NewtonFDM, RandomSubspaceNewtonFDM
- from .rfdm import RandomGaussianSmoothing, RandomizedFDM, RandomizedFDMWrapper, SPSA
- from .rs import RandomSearch, CyclicRS
torchzero/optim/zeroth_order/fdm.py
@@ -1,87 +0,0 @@
- from typing import Literal
-
- import torch
-
- from ...modules import FDM as _FDM, WrapClosure, SGD, WeightDecay, LR
- from ...modules.gradient_approximation._fd_formulas import _FD_Formulas
- from ..modular import Modular
-
-
- class FDM(Modular):
-     """Gradient approximation via finite difference.
-
-     This performs `n + 1` evaluations per step with `forward` and `backward` formulas,
-     and `2 * n` with `central` formula, where n is the number of parameters.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate. Defaults to 1e-3.
-         eps (float, optional): finite difference epsilon. Defaults to 1e-3.
-         formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
-         n_points (T.Literal[2, 3], optional): number of points for finite difference formula, 2 or 3. Defaults to 2.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         eps: float = 1e-3,
-         formula: _FD_Formulas = "forward",
-         n_points: Literal[2, 3] = 2,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False,
-
-     ):
-         modules: list = [
-             _FDM(eps = eps, formula=formula, n_points=n_points),
-             SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-             LR(lr),
-
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class FDMWrapper(Modular):
-     """Gradient approximation via finite difference. This wraps any other optimizer.
-     This also supports optimizers that perform multiple gradient evaluations per step, like LBFGS.
-
-     Exaple:
-     ```
-     lbfgs = torch.optim.LBFGS(params, lr = 1)
-     fdm = FDMWrapper(optimizer = lbfgs)
-     ```
-
-     This performs n+1 evaluations per step with `forward` and `backward` formulas,
-     and 2*n with `central` formula.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         optimizer (torch.optim.Optimizer): optimizer that will perform optimization using FDM-approximated gradients.
-         eps (float, optional): finite difference epsilon. Defaults to 1e-3.
-         formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
-         n_points (T.Literal[2, 3], optional): number of points for finite difference formula, 2 or 3. Defaults to 2.
-     """
-     def __init__(
-         self,
-         optimizer: torch.optim.Optimizer,
-         eps: float = 1e-3,
-         formula: _FD_Formulas = "forward",
-         n_points: Literal[2, 3] = 2,
-     ):
-         modules = [
-             _FDM(eps = eps, formula=formula, n_points=n_points, target = 'closure'),
-             WrapClosure(optimizer)
-         ]
-         # some optimizers have `eps` setting in param groups too.
-         # it should not be passed to FDM
-         super().__init__([p for g in optimizer.param_groups.copy() for p in g['params']], modules)
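The fragment in the `FDMWrapper` docstring above is the only usage hint left for the wrapper, so here it is expanded into a runnable sketch. This assumes gradients come purely from the finite-difference approximation, so the closure only returns the loss; per the docstring, the `forward` formula costs `n + 1` evaluations per gradient and `central` costs `2 * n`:

```
# Sketch built around the FDMWrapper docstring fragment (torchzero 0.1.x).
import torch
from torchzero.optim.zeroth_order import FDMWrapper

params = [torch.nn.Parameter(torch.randn(20))]
lbfgs = torch.optim.LBFGS(params, lr=1)      # wrapped optimizer, as in the docstring
fdm = FDMWrapper(optimizer=lbfgs, eps=1e-3, formula="forward")

def closure():
    # no backward(): .grad is filled in by finite differences (n + 1 = 21 evaluations here)
    return (params[0] - 1).pow(2).sum()

for _ in range(10):
    fdm.step(closure)
```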
torchzero/optim/zeroth_order/newton_fdm.py
@@ -1,146 +0,0 @@
- from typing import Any, Literal
- import torch
-
- from ...modules import (LR, FallbackLinearSystemSolvers,
-                         LinearSystemSolvers, LineSearches, ClipNorm)
- from ...modules import NewtonFDM as _NewtonFDM, get_line_search
- from ...modules.experimental.subspace import Proj2Masks, ProjRandom, Subspace
- from ..modular import Modular
-
-
- class NewtonFDM(Modular):
-     """Newton method with gradient and hessian approximated via finite difference.
-
-     This performs approximately `4 * n^2 + 1` evaluations per step;
-     if `diag` is True, performs `n * 2 + 1` evaluations per step.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate.
-         eps (float, optional): epsilon for finite difference.
-             Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
-         diag (bool, optional): whether to only approximate diagonal elements of the hessian.
-             This also ignores `solver` if True. Defaults to False.
-         solver (LinearSystemSolvers, optional):
-             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-         fallback (FallbackLinearSystemSolvers, optional):
-             what to do if solver fails. Defaults to "safe_diag"
-             (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-         validate (bool, optional):
-             validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-             If not, undo the step and perform a gradient descent step.
-         tol (float, optional):
-             only has effect if `validate` is enabled.
-             If loss increased by `loss * tol`, perform gradient descent step.
-             Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-         gd_lr (float, optional):
-             only has effect if `validate` is enabled.
-             Gradient descent step learning rate. Defaults to 1e-2.
-         line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to 'brent'.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1,
-         eps: float = 1e-2,
-         diag=False,
-         solver: LinearSystemSolvers = "cholesky_lu",
-         fallback: FallbackLinearSystemSolvers = "safe_diag",
-         max_norm: float | None = None,
-         validate=False,
-         tol: float = 2,
-         gd_lr = 1e-2,
-         line_search: LineSearches | None = 'brent',
-     ):
-         modules: list[Any] = [
-             _NewtonFDM(eps = eps, diag = diag, solver=solver, fallback=fallback, validate=validate, tol=tol, gd_lr=gd_lr),
-         ]
-
-         if max_norm is not None:
-             modules.append(ClipNorm(max_norm))
-
-         modules.append(LR(lr))
-
-         if line_search is not None:
-             modules.append(get_line_search(line_search))
-
-         super().__init__(params, modules)
-
-
- class RandomSubspaceNewtonFDM(Modular):
-     """This projects the parameters into a smaller dimensional subspace,
-     making approximating the hessian via finite difference feasible.
-
-     This performs approximately `4 * subspace_ndim^2 + 1` evaluations per step;
-     if `diag` is True, performs `subspace_ndim * 2 + 1` evaluations per step.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         subspace_ndim (float, optional): number of random subspace dimensions.
-         lr (float, optional): learning rate.
-         eps (float, optional): epsilon for finite difference.
-             Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
-         diag (bool, optional): whether to only approximate diagonal elements of the hessian.
-         solver (LinearSystemSolvers, optional):
-             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-         fallback (FallbackLinearSystemSolvers, optional):
-             what to do if solver fails. Defaults to "safe_diag"
-             (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-         validate (bool, optional):
-             validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-             If not, undo the step and perform a gradient descent step.
-         tol (float, optional):
-             only has effect if `validate` is enabled.
-             If loss increased by `loss * tol`, perform gradient descent step.
-             Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-         gd_lr (float, optional):
-             only has effect if `validate` is enabled.
-             Gradient descent step learning rate. Defaults to 1e-2.
-         line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to BacktrackingLS().
-         randomize_every (float, optional): generates new random projections every n steps. Defaults to 1.
-     """
-     def __init__(
-         self,
-         params,
-         subspace_ndim: int = 3,
-         lr: float = 1,
-         eps: float = 1e-2,
-         diag=False,
-         solver: LinearSystemSolvers = "cholesky_lu",
-         fallback: FallbackLinearSystemSolvers = "safe_diag",
-         max_norm: float | None = None,
-         validate=False,
-         tol: float = 2,
-         gd_lr = 1e-2,
-         line_search: LineSearches | None = 'brent',
-         randomize_every: int = 1,
-     ):
-         if subspace_ndim == 1: projections = [ProjRandom(1)]
-         else:
-             projections: list[Any] = [Proj2Masks(subspace_ndim//2)]
-             if subspace_ndim % 2 == 1: projections.append(ProjRandom(1))
-
-         modules: list[Any] = [
-             Subspace(
-                 modules = _NewtonFDM(
-                     eps = eps,
-                     diag = diag,
-                     solver=solver,
-                     fallback=fallback,
-                     validate=validate,
-                     tol=tol,
-                     gd_lr=gd_lr
-                 ),
-                 projections = projections,
-                 update_every=randomize_every),
-         ]
-         if max_norm is not None:
-             modules.append(ClipNorm(max_norm))
-
-         modules.append(LR(lr))
-
-         if line_search is not None:
-             modules.append(get_line_search(line_search))
-
-         super().__init__(params, modules)
-
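The evaluation counts quoted in the two docstrings above are the point of the subspace variant: full `NewtonFDM` on even a 10-parameter problem needs roughly `4 * 10^2 + 1 = 401` closure evaluations per step, while a 3-dimensional random subspace cuts that to about `4 * 3^2 + 1 = 37`. A sketch under those assumptions (import path taken from the deleted `zeroth_order/__init__.py`; a plain loss-returning closure is assumed since the method is gradient-free):

```
# Sketch of the removed 0.1.x RandomSubspaceNewtonFDM preset on a toy problem.
import torch
from torchzero.optim.zeroth_order import RandomSubspaceNewtonFDM

params = [torch.nn.Parameter(torch.zeros(10))]
opt = RandomSubspaceNewtonFDM(
    params,
    subspace_ndim=3,     # ~4 * 3**2 + 1 = 37 evaluations per step instead of ~401
    lr=1.0,
    eps=1e-2,            # float32 needs a fairly large FD epsilon, per the docstring
    randomize_every=1,   # fresh random projections every step
)

def closure():
    # zeroth-order: only the loss value is used
    return (params[0] - torch.arange(10.0)).pow(2).sum()

for _ in range(20):
    opt.step(closure)
```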