torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -1,105 +0,0 @@
- from abc import ABC, abstractmethod
- from functools import partial
- from typing import Any, Literal
-
- import torch
-
- from ...core import (
-     OptimizationVars,
-     OptimizerModule,
-     _ClosureType,
-     _maybe_pass_backward,
-     _ScalarLoss,
- )
- from ...tensorlist import TensorList
-
-
- class GradientApproximatorBase(OptimizerModule, ABC):
-     """Base gradient approximator class. This is an abstract class, please don't use it as the optimizer.
-
-     When inheriting from this class the easiest way is to override `_make_ascent`, which should
-     return the ascent direction (like approximated gradient).
-
-     Args:
-         defaults (dict[str, Any]): defaults
-         requires_fx0 (bool):
-             if True, makes sure to calculate fx0 beforehand.
-             This means `_make_ascent` will always receive a pre-calculated `fx0` that won't be None.
-
-         target (str, optional):
-             determines what this module sets.
-
-             "ascent" - it creates a new ascent direction but doesn't treat is as gradient.
-
-             "grad" - it creates the gradient and sets it to `.grad` attributes (default).
-
-             "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-     """
-     def __init__(self, defaults: dict[str, Any], requires_fx0: bool, target: Literal['ascent', 'grad', 'closure']):
-         super().__init__(defaults, target)
-         self.requires_fx0 = requires_fx0
-
-     def _step_make_closure_(self, vars: OptimizationVars, params: TensorList):
-         if vars.closure is None: raise ValueError("gradient approximation requires closure")
-         closure = vars.closure
-
-         if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
-         else: fx0 = vars.fx0
-
-         def new_closure(backward=True) -> _ScalarLoss:
-             if backward:
-                 g, ret_fx0, ret_fx0_approx = self._make_ascent(closure, params, fx0)
-                 params.set_grad_(g)
-
-                 if ret_fx0 is None: return ret_fx0_approx # type:ignore
-                 return ret_fx0
-
-             return closure(False)
-
-         vars.closure = new_closure
-
-     def _step_make_target_(self, vars: OptimizationVars, params: TensorList):
-         if vars.closure is None: raise ValueError("gradient approximation requires closure")
-
-         if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
-         else: fx0 = vars.fx0
-
-         g, vars.fx0, vars.fx0_approx = self._make_ascent(vars.closure, params, fx0)
-         if self._default_step_target == 'ascent': vars.ascent = g
-         elif self._default_step_target == 'grad': vars.set_grad_(g, params)
-         else: raise ValueError(f"Unknown target {self._default_step_target}")
-
-     @torch.no_grad
-     def step(self, vars: OptimizationVars):
-         params = self.get_params()
-         if self._default_step_target == 'closure':
-             self._step_make_closure_(vars, params)
-
-         else:
-             self._step_make_target_(vars, params)
-
-         return self._update_params_or_step_with_next(vars, params)
-
-     @abstractmethod
-     @torch.no_grad
-     def _make_ascent(
-         self,
-         # vars: OptimizationVars,
-         closure: _ClosureType,
-         params: TensorList,
-         fx0: Any,
-     ) -> tuple[TensorList, _ScalarLoss | None, _ScalarLoss | None]:
-         """This should return a tuple of 3 elements:
-
-         .. code:: py
-
-             (ascent, fx0, fx0_approx)
-
-         Args:
-             closure (_ClosureType): closure
-             params (TensorList): parameters
-             fx0 (Any): fx0, can be None unless :target:`requires_fx0` is True on this module.
-
-         Returns:
-             (ascent, fx0, fx0_approx)
-         """
@@ -1,125 +0,0 @@
- from typing import Literal, Any
- from warnings import warn
- import torch
-
- from ...utils.python_tools import _ScalarLoss
- from ...tensorlist import TensorList
- from ...core import _ClosureType, OptimizerModule, OptimizationVars
- from ._fd_formulas import _FD_Formulas
- from .base_approximator import GradientApproximatorBase
-
- def _two_point_fd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-     """Two point finite difference (same signature for all other finite differences functions).
-
-     Args:
-         closure (Callable): A closure that reevaluates the model and returns the loss.
-         idx (int): Flat index of the current parameter.
-         pvec (Tensor): Flattened view of the current parameter tensor.
-         gvec (Tensor): Flattened view of the current parameter tensor gradient.
-         eps (float): Finite difference epsilon.
-         fx0 (ScalarType): Loss at fx0, to avoid reevaluating it each time. On some functions can be None when it isn't needed.
-
-     Returns:
-         This modifies `gvec` in place.
-         This returns loss, not necessarily at fx0 (for example central difference never evaluate at fx0).
-         So this should be assigned to fx0_approx.
-     """
-     pvec[idx] += eps
-     fx1 = closure(False)
-     gvec[idx] = (fx1 - fx0) / eps
-     pvec[idx] -= eps
-     return fx0
-
- def _two_point_bd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-     pvec[idx] += eps
-     fx1 = closure(False)
-     gvec[idx] = (fx0 - fx1) / eps
-     pvec[idx] -= eps
-     return fx0
-
- def _two_point_cd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0 = None, ):
-     pvec[idx] += eps
-     fxplus = closure(False)
-     pvec[idx] -= eps * 2
-     fxminus = closure(False)
-     gvec[idx] = (fxplus - fxminus) / (2 * eps)
-     pvec[idx] += eps
-     return fxplus
-
- def _three_point_fd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-     pvec[idx] += eps
-     fx1 = closure(False)
-     pvec[idx] += eps
-     fx2 = closure(False)
-     gvec[idx] = (-3*fx0 + 4*fx1 - fx2) / (2 * eps)
-     pvec[idx] -= 2 * eps
-     return fx0
-
- def _three_point_bd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-     pvec[idx] -= eps
-     fx1 = closure(False)
-     pvec[idx] -= eps
-     fx2 = closure(False)
-     gvec[idx] = (fx2 - 4*fx1 + 3*fx0) / (2 * eps)
-     pvec[idx] += 2 * eps
-     return fx0
-
-
- class FDM(GradientApproximatorBase):
-     """Gradient approximation via finite difference.
-
-     This performs :math:`num_parameters + 1` or :math:`num_parameters * 2` evaluations per step, depending on formula.
-
-     Args:
-         eps (float, optional): finite difference epsilon. Defaults to 1e-5.
-         formula (_FD_Formulas, optional): finite difference formula. Defaults to 'forward'.
-         n_points (T.Literal[2, 3], optional): number of points, 2 or 3. Defaults to 2.
-         target (str, optional):
-             determines what this module sets.
-
-             "ascent" - it creates a new ascent direction but doesn't treat is as gradient.
-
-             "grad" - it creates the gradient and sets it to `.grad` attributes (default).
-
-             "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-     """
-     def __init__(
-         self,
-         eps: float = 1e-5,
-         formula: _FD_Formulas = "forward",
-         n_points: Literal[2, 3] = 2,
-         target: Literal["ascent", "grad", "closure"] = "grad",
-     ):
-         defaults = dict(eps = eps)
-
-         if formula == 'central':
-             self._finite_difference_ = _two_point_cd_ # this is both 2 and 3 point formula
-             requires_fx0 = False
-
-         elif formula == 'forward':
-             if n_points == 2: self._finite_difference_ = _two_point_fd_
-             else: self._finite_difference_ = _three_point_fd_
-             requires_fx0 = True
-
-         elif formula == 'backward':
-             if n_points == 2: self._finite_difference_ = _two_point_bd_
-             else: self._finite_difference_ = _three_point_bd_
-             requires_fx0 = True
-
-         else: raise ValueError(f'{formula} is not valid.')
-
-         super().__init__(defaults, requires_fx0=requires_fx0, target = target)
-
-     @torch.no_grad
-     def _make_ascent(self, closure, params, fx0):
-         grads = params.zeros_like()
-         epsilons = self.get_group_key('eps')
-
-         fx0_approx = None
-         for p, g, eps in zip(params, grads, epsilons):
-             flat_param = p.view(-1)
-             flat_grad = g.view(-1)
-             for idx in range(flat_param.numel()):
-                 fx0_approx = self._finite_difference_(closure, idx, flat_param, flat_grad, eps, fx0)
-
-         return grads, fx0, fx0_approx
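
The two-point forward difference used above approximates each partial derivative as (f(x + eps*e_i) - f(x)) / eps, so a full gradient costs one extra function evaluation per parameter (central differences cost two). A self-contained sketch of that estimate, independent of the torchzero module system (`fd_gradient` is an illustrative name, not part of the package):

    import torch

    def fd_gradient(f, x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Forward-difference gradient estimate of a scalar function f at x.
        f0 = f(x)
        g = torch.zeros_like(x)
        flat_x, flat_g = x.view(-1), g.view(-1)
        for i in range(flat_x.numel()):
            flat_x[i] += eps                 # perturb one coordinate in place
            flat_g[i] = (f(x) - f0) / eps
            flat_x[i] -= eps                 # restore it
        return g

    x = torch.tensor([1.0, 2.0, 3.0])
    print(fd_gradient(lambda v: (v ** 2).sum(), x))  # approximately [2., 4., 6.]
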
@@ -1,163 +0,0 @@
- from collections.abc import Iterable
- from typing import Literal
-
- import torch
- import torch.autograd.forward_ad as fwAD
-
- from ...core import OptimizerModule, _ClosureType
- from ...tensorlist import TensorList
- from ...random import Distributions
- from ...utils.torch_tools import swap_tensors_no_use_count_check
- from .base_approximator import GradientApproximatorBase
-
- def get_forward_gradient(
-     params: Iterable[torch.Tensor],
-     closure: _ClosureType,
-     n_samples: int,
-     distribution: Distributions,
-     mode: Literal["jvp", "grad", "fd"],
-     fd_eps: float = 1e-4,
- ):
-     """Evaluates forward gradient of a closure w.r.t iterable of parameters with a random tangent vector.
-
-     Args:
-         params (Iterable[torch.Tensor]): iterable of parameters of the model.
-         closure (_ClosureType):
-             A closure that reevaluates the model and returns the loss.
-             Closure must accept `backward = True` boolean argument. Forward gradient will always call it as
-             `closure(False)`, unless `mode = "grad"` which requires a backward pass.
-         n_samples (int): number of forward gradients to evaluate and average.
-         distribution (Distributions): distribution for random tangent vector.
-         mode (str):
-             "jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory.
-
-             "grad" - evaluates gradient with `loss.backward()` which may be faster but uses all the memory, mainly useful for
-             benchmarking as there is probably no point in forward gradient if full gradient is available.
-
-             "fd" - uses finite difference to estimate JVP in two forward passes,
-             doesn't require the objective to be autodiffable. Equivalent to randomized FDM.
-
-         fd_eps (float, optional): epsilon for finite difference, only has effect if mode is "fd". Defaults to 1e-4.
-
-     Returns:
-         TensorList: list of estimated gradients of the same structure and shape as `params`.
-
-     Reference:
-         Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022).
-         Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
-         https://arxiv.org/abs/2202.08587
-     """
-     if not isinstance(params, TensorList): params = TensorList(params)
-     params = params.with_requires_grad()
-
-     orig_params = None
-     grad = None
-     loss = None
-     for _ in range(n_samples):
-
-         # generate random vector
-         tangents = params.sample_like(fd_eps if mode == 'fd' else 1, distribution)
-
-         if mode == 'jvp':
-             if orig_params is None:
-                 orig_params = params.clone().requires_grad_()
-
-             # evaluate jvp with it
-             with fwAD.dual_level():
-
-                 # swap to duals
-                 for param, clone, tangent in zip(params, orig_params, tangents):
-                     dual = fwAD.make_dual(clone, tangent)
-                     torch.utils.swap_tensors(param, dual)
-
-                 loss = closure(False)
-                 jvp = fwAD.unpack_dual(loss).tangent
-
-         elif mode == 'grad':
-             with torch.enable_grad(): loss = closure()
-             jvp = tangents.mul(params.ensure_grad_().grad).sum()
-
-         elif mode == 'fd':
-             loss = closure(False)
-             params += tangents
-             loss2 = closure(False)
-             params -= tangents
-             jvp = (loss2 - loss) / fd_eps**2
-
-         else:
-             raise ValueError(mode)
-
-         # update grad estimate
-         if grad is None: grad = tangents * jvp
-         else: grad += tangents * jvp
-
-     # swap back to original params
-     if orig_params is not None:
-         for param, orig in zip(params, orig_params):
-             swap_tensors_no_use_count_check(param, orig)
-
-     assert grad is not None
-     assert loss is not None
-     if n_samples > 1:
-         grad /= n_samples
-
-     return grad, loss
-
- class ForwardGradient(GradientApproximatorBase):
-     """Evaluates jacobian-vector product with a random vector using forward mode autodiff (torch.autograd.forward_ad), which is
-     the true directional derivative in the direction of that vector.
-
-     Args:
-         n_samples (int): number of forward gradients to evaluate and average.
-         distribution (Distributions): distribution for random tangent vector.
-         mode (str):
-             "jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory,
-             because it doesn't have to store intermediate activations.
-
-             "grad" - evaluates gradient with `loss.backward()` which may be faster but uses all the memory, mainly useful for
-             benchmarking as there is probably no point in forward gradient if full gradient is available.
-
-             "fd" - uses finite difference to estimate JVP in two forward passes,
-             doesn't require the objective to be autodiffable. Equivalent to randomized FDM.
-
-         fd_eps (float, optional): epsilon for finite difference, only has effect if mode is "fd". Defaults to 1e-4.
-         target (str, optional):
-             determines what this module sets.
-
-             "ascent" - it creates a new ascent direction but doesn't treat is as gradient.
-
-             "grad" - it creates the gradient and sets it to `.grad` attributes (default).
-
-             "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-
-     Reference:
-         Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022).
-         Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
-         https://arxiv.org/abs/2202.08587
-     """
-     def __init__(
-         self,
-         n_samples: int = 1,
-         distribution: Distributions = "normal",
-         mode: Literal["jvp", "grad", "fd"] = "jvp",
-         fd_eps: float = 1e-4,
-         target: Literal['ascent', 'grad', 'closure'] = 'grad',
-     ):
-         super().__init__({}, requires_fx0=False, target = target)
-         self.distribution: Distributions = distribution
-         self.n_samples = n_samples
-         self.mode: Literal["jvp", "grad", "fd"] = mode
-         self.fd_eps = fd_eps
-
-
-     def _make_ascent(self, closure, params, fx0):
-         g, fx0 = get_forward_gradient(
-             params=params,
-             closure=closure,
-             n_samples=self.n_samples,
-             distribution=self.distribution,
-             mode=self.mode,
-             fd_eps=self.fd_eps,
-         )
-
-         return g, fx0, None
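
The estimator above is the forward-gradient idea from Baydin et al. (2022): sample a random tangent v, compute the directional derivative (a JVP) with forward-mode AD, and use (grad·v)·v as an unbiased gradient estimate, optionally averaged over several samples. A minimal sketch using torch.autograd.forward_ad directly, outside the torchzero module system (`forward_gradient` here is an illustrative helper, not the package API):

    import torch
    import torch.autograd.forward_ad as fwAD

    def forward_gradient(f, x: torch.Tensor, n_samples: int = 8) -> torch.Tensor:
        # Monte Carlo forward-gradient estimate: average of (grad . v) * v.
        est = torch.zeros_like(x)
        for _ in range(n_samples):
            v = torch.randn_like(x)  # random tangent vector
            with fwAD.dual_level():
                dual_x = fwAD.make_dual(x, v)
                jvp = fwAD.unpack_dual(f(dual_x)).tangent  # directional derivative grad . v
            est += v * jvp
        return est / n_samples

    x = torch.tensor([1.0, -2.0])
    print(forward_gradient(lambda v: (v ** 3).sum(), x))  # approaches [3., 12.] in expectation
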
@@ -1,198 +0,0 @@
- import typing as T
-
- import torch
-
- from ...utils.python_tools import _ScalarLoss
- from ...tensorlist import TensorList
- from ...core import _ClosureType, OptimizerModule
- from ..second_order.newton import (LINEAR_SYSTEM_SOLVERS,
-                                    FallbackLinearSystemSolvers,
-                                    LinearSystemSolvers, _fallback_gd)
- from ._fd_formulas import _FD_Formulas
-
-
- def _three_point_2cd_(
-     closure: _ClosureType,
-     idx1: int,
-     idx2: int,
-     p1: torch.Tensor,
-     p2: torch.Tensor,
-     g1: torch.Tensor,
-     hessian: torch.Tensor,
-     eps1: _ScalarLoss,
-     eps2: _ScalarLoss,
-     i1: int,
-     i2: int,
-     fx0: _ScalarLoss,
- ):
-     """Second order three point finite differences (same signature for all other 2nd order finite differences functions).
-
-     Args:
-         closure (ClosureType): _description_
-         idx1 (int): _description_
-         idx2 (int): _description_
-         p1 (torch.Tensor): _description_
-         p2 (torch.Tensor): _description_
-         g1 (torch.Tensor): _description_
-         g2 (torch.Tensor): _description_
-         hessian (torch.Tensor): _description_
-         eps1 (ScalarType): _description_
-         eps2 (ScalarType): _description_
-         i1 (int): _description_
-         i23 (int): _description_
-         fx0 (ScalarType): _description_
-
-     """
-     # same param
-     if i1 == i2 and idx1 == idx2:
-         p1[idx1] += eps1
-         fxplus = closure(False)
-
-         p1[idx1] -= 2*eps1
-         fxminus = closure(False)
-
-         p1[idx1] += eps1
-
-         g1[idx1] = (fxplus - fxminus) / (2 * eps1)
-         hessian[i1, i2] = (fxplus - 2*fx0 + fxminus) / eps1**2
-
-     else:
-         p1[idx1] += eps1
-         p2[idx2] += eps2
-         fxpp = closure(False)
-         p1[idx1] -= eps1*2
-         fxnp = closure(False)
-         p2[idx2] -= eps2*2
-         fxnn = closure(False)
-         p1[idx1] += eps1*2
-         fxpn = closure(False)
-
-         p1[idx1] -= eps1
-         p2[idx2] += eps2
-
-         hessian[i1, i2] = (fxpp - fxpn - fxnp + fxnn) / (4 * eps1 * eps2)
-
-
- class NewtonFDM(OptimizerModule):
-     """Newton method with gradient and hessian approximated via finite difference.
-
-     Args:
-         eps (float, optional):
-             epsilon for finite difference.
-             Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
-         diag (bool, optional):
-             whether to only approximate diagonal elements of the hessian.
-             If true, ignores `solver` and `fallback`. Defaults to False.
-         solver (LinearSystemSolvers, optional):
-             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-         fallback (FallbackLinearSystemSolvers, optional):
-             what to do if solver fails. Defaults to "safe_diag"
-             (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-         validate (bool, optional):
-             validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-             If not, undo the step and perform a gradient descent step.
-         tol (float, optional):
-             only has effect if `validate` is enabled.
-             If loss increased by `loss * tol`, perform gradient descent step.
-             Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-         gd_lr (float, optional):
-             only has effect if `validate` is enabled.
-             Gradient descent step learning rate. Defaults to 1e-2.
-
-     """
-     def __init__(
-         self,
-         eps: float = 1e-2,
-         diag=False,
-         solver: LinearSystemSolvers = "cholesky_lu",
-         fallback: FallbackLinearSystemSolvers = "safe_diag",
-         validate=False,
-         tol: float = 1,
-         gd_lr = 1e-2,
-     ):
-         defaults = dict(eps = eps)
-         super().__init__(defaults)
-         self.diag = diag
-         self.solver = LINEAR_SYSTEM_SOLVERS[solver]
-         self.fallback = LINEAR_SYSTEM_SOLVERS[fallback]
-
-         self.validate = validate
-         self.gd_lr = gd_lr
-         self.tol = tol
-
-     @torch.no_grad
-     def step(self, vars):
-         """Returns a new ascent direction."""
-         if vars.closure is None: raise ValueError('NewtonFDM requires a closure.')
-         if vars.ascent is not None: raise ValueError('NewtonFDM got ascent direction')
-
-         params = self.get_params()
-         epsilons = self.get_group_key('eps')
-
-         # evaluate fx0.
-         if vars.fx0 is None: vars.fx0 = vars.closure(False)
-
-         # evaluate gradients and hessian via finite differences.
-         grads = params.zeros_like()
-         numel = params.total_numel()
-         hessian = torch.zeros((numel, numel), dtype = params[0].dtype, device = params[0].device)
-
-         cur1 = 0
-         for p1, g1, eps1 in zip(params, grads, epsilons):
-             flat_param1 = p1.view(-1)
-             flat_grad1 = g1.view(-1)
-             for idx1 in range(flat_param1.numel()):
-
-                 cur2 = 0
-                 for p2, eps2 in zip(params, epsilons):
-
-                     flat_param2 = p2.view(-1)
-                     for idx2 in range(flat_param2.numel()):
-                         if self.diag and (idx1 != idx2 or cur1 != cur2):
-                             cur2 += 1
-                             continue
-                         _three_point_2cd_(
-                             closure = vars.closure,
-                             idx1 = idx1,
-                             idx2 = idx2,
-                             p1 = flat_param1,
-                             p2 = flat_param2,
-                             g1 = flat_grad1,
-                             hessian = hessian,
-                             eps1 = eps1,
-                             eps2 = eps2,
-                             fx0 = vars.fx0,
-                             i1 = cur1,
-                             i2 = cur2,
-                         )
-                         cur2 += 1
-                 cur1 += 1
-
-         gvec = grads.to_vec()
-         if self.diag:
-             hdiag = hessian.diag()
-             hdiag[hdiag == 0] = 1
-             newton_step = gvec / hdiag
-         else:
-             newton_step, success = self.solver(hessian, gvec)
-             if not success:
-                 newton_step, success = self.fallback(hessian, gvec)
-                 if not success:
-                     newton_step, success = _fallback_gd(hessian, gvec)
-
-         # update params or pass the gradients to the child.
-         vars.ascent = grads.from_vec(newton_step)
-
-
-         # validate if newton step decreased loss
-         if self.validate:
-
-             params.sub_(vars.ascent)
-             fx1 = vars.closure(False)
-             params.add_(vars.ascent)
-
-             # if loss increases, set ascent direction to gvec times lr
-             if fx1 - vars.fx0 > vars.fx0 * self.tol:
-                 vars.ascent = grads.from_vec(gvec) * self.gd_lr
-
-         return self._update_params_or_step_with_next(vars, params)
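
NewtonFDM above fills the Hessian entry by entry with central differences: the diagonal uses (f(x+eps*e_i) - 2 f(x) + f(x-eps*e_i)) / eps^2, the off-diagonals a four-point mixed formula, and the step comes from solving H d = g. A compact sketch of the diagonal-only variant on a toy quadratic, independent of the package (`newton_fdm_diag_step` is an illustrative name):

    import torch

    def newton_fdm_diag_step(f, x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
        # One diagonal Newton step: gradient and Hessian diagonal via central differences.
        f0 = f(x)
        g = torch.zeros_like(x)
        h = torch.zeros_like(x)
        flat_x, flat_g, flat_h = x.view(-1), g.view(-1), h.view(-1)
        for i in range(flat_x.numel()):
            flat_x[i] += eps
            f_plus = f(x)
            flat_x[i] -= 2 * eps
            f_minus = f(x)
            flat_x[i] += eps                                    # restore the coordinate
            flat_g[i] = (f_plus - f_minus) / (2 * eps)          # central-difference gradient
            flat_h[i] = (f_plus - 2 * f0 + f_minus) / eps ** 2  # second-derivative estimate
        h = torch.where(h == 0, torch.ones_like(h), h)          # guard against dividing by zero
        return g / h                                            # Newton step: x_new = x - step

    x = torch.tensor([3.0, -1.5])
    x -= newton_fdm_diag_step(lambda v: (v ** 2).sum(), x)      # one step solves a quadratic
    print(x)                                                    # approximately [0., 0.]
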