torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/line_search/scipy_minimize_scalar.py (deleted)
@@ -1,62 +0,0 @@
- import typing
-
- import torch
- try:
-     import scipy.optimize as scopt
- except ModuleNotFoundError:
-     scopt = typing.cast(typing.Any, None)
-
- from ...tensorlist import TensorList
- from ...core import OptimizationVars
-
- from .base_ls import LineSearchBase, MaxIterReached
-
- if typing.TYPE_CHECKING:
-     import scipy.optimize as scopt
-
- class ScipyMinimizeScalarLS(LineSearchBase):
-     """Line search via `scipy.optimize.minimize_scalar`. All args except maxiter are the same as for it.
-
-     Args:
-         method (Optional[str], optional): 'brent', 'golden' or 'bounded'. Defaults to None.
-         maxiter (Optional[int], optional): hard limit on maximum number of function evaluations. Defaults to None.
-         bracket (optional): bracket. Defaults to None.
-         bounds (optional): bounds. Defaults to None.
-         tol (Optional[float], optional): some kind of tolerance. Defaults to None.
-         options (optional): options for method. Defaults to None.
-         log_lrs (bool, optional): logs lrs and values into `_lrs`. Defaults to False.
-     """
-     def __init__(
-         self,
-         method: str | None = None,
-         maxiter: int | None = None,
-         bracket = None,
-         bounds = None,
-         tol: float | None = None,
-         options = None,
-         log_lrs = False,
-     ):
-         if scopt is None: raise ModuleNotFoundError("scipy is not installed")
-         super().__init__({}, maxiter=maxiter, log_lrs=log_lrs)
-         self.method = method
-         self.tol = tol
-         self.bracket = bracket
-         self.bounds = bounds
-         self.options = options
-
-     @torch.no_grad
-     def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-         try:
-             res = scopt.minimize_scalar(
-                 self._evaluate_lr_ensure_float,
-                 args = (vars.closure, vars.ascent, params),
-                 method = self.method,
-                 tol = self.tol,
-                 bracket = self.bracket,
-                 bounds = self.bounds,
-                 options = self.options,
-             ) # type:ignore
-         except MaxIterReached:
-             pass
-
-         return float(self._best_lr)
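The removed ScipyMinimizeScalarLS reduced the line search to a one-dimensional problem: scipy.optimize.minimize_scalar is handed "loss as a function of the step size along a fixed search direction". A minimal self-contained sketch of that idea using plain torch and scipy on a toy least-squares objective (the loss_at helper and all names here are illustrative, not torchzero API):

```python
import scipy.optimize
import torch

# Toy objective: f(x) = ||Ax - b||^2, line-searched along the steepest-descent direction.
torch.manual_seed(0)
A, b = torch.randn(10, 4), torch.randn(10)
x = torch.zeros(4, requires_grad=True)

loss = ((A @ x - b) ** 2).sum()
loss.backward()
direction = x.grad.clone()              # fixed search direction for this step

def loss_at(lr: float) -> float:
    # Evaluate the objective at x - lr * direction without permanently moving x.
    with torch.no_grad():
        x.sub_(direction, alpha=float(lr))
        f = ((A @ x - b) ** 2).sum().item()
        x.add_(direction, alpha=float(lr))
    return f

res = scipy.optimize.minimize_scalar(loss_at, method="brent", tol=1e-6)
with torch.no_grad():
    x.sub_(direction, alpha=float(res.x))   # commit the best step size found
print(f"best lr {res.x:.4f}, new loss {res.fun:.4f}")
```

minimize_scalar probes several step sizes per call, so this trades extra loss evaluations for a better step, which is the trade-off the module's maxiter cap guarded against.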
torchzero/modules/meta/__init__.py (deleted)
@@ -1,12 +0,0 @@
- """Modules that use other modules."""
- # from .chain import Chain, ChainReturn
- import sys
-
- from .alternate import Alternate
- from .grafting import Graft, IntermoduleCautious, SignGrafting
- from .return_overrides import ReturnAscent, ReturnClosure, SetGrad
-
- # if sys.version_info[1] < 12:
- from .optimizer_wrapper import Wrap, WrapClosure
- # else:
- # from .optimizer_wrapper import Wrap, WrapClosure
torchzero/modules/meta/alternate.py (deleted)
@@ -1,65 +0,0 @@
- import random
- from collections.abc import Iterable
- from typing import Any, Literal
-
- from ...core import OptimizerModule, _Chainable
-
-
- class Alternate(OptimizerModule):
-     """Alternates stepping with multiple modules.
-
-     Args:
-         modules (Iterable[OptimizerModule | Iterable[OptimizerModule]]): modules to alternate between.
-         mode (int | list[int] | tuple[int] | "random"], optional):
-             can be integer - number of repeats for all modules;
-             list or tuple of integers per each module with number of repeats;
-             "random" to pick module randomly each time. Defaults to 1.
-         seed (int | None, optional): seed for "random" mode. Defaults to None.
-     """
-     def __init__(
-         self,
-         modules: Iterable[_Chainable],
-         mode: int | list[int] | tuple[int] | Literal["random"] = 1,
-         seed: int | None = None
-     ):
-         super().__init__({})
-         modules = list(modules)
-
-         for i,m in enumerate(modules):
-             self._set_child_(i, m)
-
-         self.random = random.Random(seed)
-
-         if isinstance(mode, int): mode = [mode for _ in modules]
-         self.mode: list[int] | tuple[int] | Literal['random'] = mode
-
-         self.cur = 0
-         if self.mode == 'random': self.remaining = 0
-         else:
-             self.remaining = self.mode[0]
-             if len(self.mode) != len(self.children):
-                 raise ValueError(f"got {len(self.children)} modules but {len(mode)} repeats, they should be the same")
-
-     def step(self, vars):
-         if self.mode == 'random':
-             module = self.random.choice(list(self.children.values()))
-
-         else:
-             if self.remaining == 0:
-                 self.cur += 1
-
-                 if self.cur >= len(self.mode):
-                     self.cur = 0
-
-             if self.remaining == 0: self.remaining = self.mode[self.cur]
-
-             module = self.children[self.cur]
-
-             self.remaining -= 1
-
-         if self.next_module is None:
-             return module.step(vars)
-
-         vars.ascent = module.return_ascent(vars)
-         return self._update_params_or_step_with_next(vars)
-
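The Alternate docstring above describes a repeat-count schedule over its child modules. As a rough standalone illustration of that scheduling idea, using stock torch.optim optimizers in place of torchzero modules (the two-optimizer setup and the repeat counts are made up for the example):

```python
import itertools
import torch

# Two inner optimizers over the same parameter, alternated with repeat counts,
# mirroring mode=[2, 1]: two SGD steps, then one Adam step, repeating.
w = torch.nn.Parameter(torch.tensor([3.0, -2.0]))
opts = [torch.optim.SGD([w], lr=0.1), torch.optim.Adam([w], lr=0.1)]
repeats = [2, 1]

# Build the repeating schedule of optimizer indices: 0, 0, 1, 0, 0, 1, ...
schedule = itertools.cycle([i for i, r in enumerate(repeats) for _ in range(r)])

for step, idx in zip(range(9), schedule):
    opt = opts[idx]
    opt.zero_grad()
    loss = (w ** 2).sum()          # simple quadratic objective
    loss.backward()
    opt.step()
    print(f"step {step}: used optimizer {idx}, loss {loss.item():.4f}")
```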
torchzero/modules/meta/grafting.py (deleted)
@@ -1,195 +0,0 @@
- from collections.abc import Iterable
- from typing import Literal
- import torch
-
- from ...core import OptimizerModule
- from ...tensorlist import TensorList
-
-
- class Graft(OptimizerModule):
-     """
-     Optimizer grafting (magnitude#direction).
-     Takes update of one optimizer and makes its norm same as update of another optimizer.
-     Can be applied to all weights or layerwise.
-
-     Args:
-         magnitude (OptimizerModule | Iterable[OptimizerModule]):
-             module to use magnitude from.
-             If sequence of modules is provided, they will be chained.
-         direction (OptimizerModule | Iterable[OptimizerModule]):
-             module/modules to use direction from.
-             If sequence of modules is provided, they will be chained.
-         ord (int, optional): norm type. Defaults to 2.
-         eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-         layerwise (bool, optional): whether to apply grafting layerwise. Defaults to False.
-
-     reference
-         *Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C.
-         Learning Rate Grafting: Transferability of Optimizer Tuning.*
-     """
-     def __init__(
-         self,
-         magnitude: OptimizerModule | Iterable[OptimizerModule],
-         direction: OptimizerModule | Iterable[OptimizerModule],
-         ord: float = 2,
-         eps: float = 1e-8,
-         layerwise: bool = False,
-         # TODO: channelwise
-     ):
-         super().__init__({})
-         self._set_child_('magnitude', magnitude)
-         self._set_child_('direction', direction)
-         self.ord = ord
-         self.eps = eps
-         self.layerwise = layerwise
-
-
-     @torch.no_grad
-     def step(self, vars):
-         state_copy = vars.copy(clone_ascent=True)
-         magnitude = self.children['magnitude'].return_ascent(state_copy)
-
-         if state_copy.grad is not None: vars.grad = state_copy.grad
-         if state_copy.fx0 is not None: vars.fx0 = state_copy.fx0
-         if state_copy.fx0_approx is not None: vars.fx0_approx = state_copy.fx0_approx
-
-         direction = self.children['direction'].return_ascent(vars)
-
-         if self.layerwise:
-             M = magnitude.norm(self.ord)
-             D = direction.norm(self.ord)
-             D.select_set_(D == 0, M)
-
-         else:
-             M = magnitude.total_vector_norm(self.ord)
-             D = direction.total_vector_norm(self.ord)
-             if D == 0: D = M
-
-         vars.ascent = direction.mul_(M / (D + self.eps))
-         return self._update_params_or_step_with_next(vars)
-
-
-
- class SignGrafting(OptimizerModule):
-     """Weight-wise grafting-like operation where sign of the ascent is taken from first module
-     and magnitude from second module.
-
-     Args:
-         magnitude (OptimizerModule | Iterable[OptimizerModule]):
-             module to take magnitude from.
-             If sequence of modules is provided, they will be chained.
-         sign (OptimizerModule | Iterable[OptimizerModule]):
-             module to take sign from.
-             If sequence of modules is provided, they will be chained.
-     """
-     def __init__(
-         self,
-         magnitude: OptimizerModule | Iterable[OptimizerModule],
-         sign: OptimizerModule | Iterable[OptimizerModule],
-     ):
-         super().__init__({})
-
-         self._set_child_('magnitude', magnitude)
-         self._set_child_('sign', sign)
-
-
-     @torch.no_grad
-     def step(self, vars):
-         state_copy = vars.copy(clone_ascent=True)
-         magnitude = self.children['magnitude'].return_ascent(state_copy)
-
-         # make sure to store grad and fx0 if it was calculated
-         vars.update_attrs_(state_copy)
-
-         sign = self.children['sign'].return_ascent(vars)
-
-         vars.ascent = magnitude.copysign_(sign)
-         return self._update_params_or_step_with_next(vars)
-
-
- class IntermoduleCautious(OptimizerModule):
-     """Negates update for parameters where updates of two modules or module chains have inconsistent sign.
-     Optionally normalizes the update by the number of parameters that are not masked.
-
-     Args:
-         main_module (OptimizerModule | Iterable[OptimizerModule]):
-             main module or sequence of modules to chain, which update will be used with a consistency mask applied.
-         compare_module (OptimizerModule | Iterable[OptimizerModule]):
-             module or sequence of modules to chain, which update will be used to compute a consistency mask.
-             Can also be set to `ascent` to compare to update that is passed `main_module`, or `grad` to compare
-             to gradients.
-         normalize (bool, optional):
-             renormalize update after masking.
-             only has effect when mode is 'zero'. Defaults to False.
-         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
-         mode (str, optional):
-             what to do with updates with inconsistent signs.
-
-             "zero" - set them to zero (as in paper)
-
-             "grad" - set them to the gradient
-
-             "compare_module" - set them to `compare_module`'s update
-
-             "negate" - negate them (same as using update magnitude and gradient sign)
-     """
-     def __init__(
-         self,
-         main_module: OptimizerModule | Iterable[OptimizerModule],
-         compare_module: OptimizerModule | Iterable[OptimizerModule] | Literal['ascent', 'grad'],
-         normalize=False,
-         eps=1e-6,
-         mode: Literal["zero", "grad", "backtrack", "compare_module"] = "zero",
-     ):
-         super().__init__({})
-
-         self._set_child_('main',main_module)
-         if isinstance(compare_module, str): self.compare_mode = compare_module
-         else:
-             self._set_child_('compare', compare_module)
-             self.compare_mode = 'module'
-         self.eps = eps
-         self.normalize = normalize
-         self.mode: Literal["zero", "grad", "backtrack", "compare_module"] = mode
-
-     @torch.no_grad
-     def step(self, vars):
-         params = None
-         state_copy = vars.copy(clone_ascent=True)
-         ascent = self.children['main'].return_ascent(state_copy)
-         vars.update_attrs_(state_copy)
-
-         if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(vars)
-         else:
-             params = self.get_params()
-             if self.compare_mode == 'ascent': compare: TensorList = vars.maybe_use_grad_(params)
-             elif self.compare_mode == 'grad': compare: TensorList = vars.maybe_compute_grad_(params)
-             else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')
-
-         # mask will be > 0 for parameters where both signs are the same
-         mask = (ascent * compare) > 0
-
-         if self.mode == 'backtrack':
-             ascent -= ascent.mul(2).mul_(mask.logical_not_())
-
-         else:
-             # normalize if mode is `zero`
-             if self.normalize and self.mode == 'zero':
-                 fmask = mask.to(ascent[0].dtype)
-                 fmask /= fmask.total_mean() + self.eps
-             else:
-                 fmask = mask
-
-             # apply the mask
-             ascent *= fmask
-
-             if self.mode == 'grad':
-                 params = self.get_params()
-                 ascent += vars.maybe_compute_grad_(params) * mask.logical_not_()
-
-             elif self.mode == 'compare_module':
-                 ascent += compare * mask.logical_not_()
-
-         vars.ascent = ascent
-         return self._update_params_or_step_with_next(vars, params)
-
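The Graft docstring above comes down to one formula: keep the direction of one optimizer's update and rescale it to the norm of the other's, i.e. update = direction * ||magnitude|| / (||direction|| + eps). A minimal global (non-layerwise) sketch of that operation in plain torch; the graft function and the two example updates are illustrative, not the torchzero API:

```python
import torch

def graft(magnitude_update: torch.Tensor, direction_update: torch.Tensor,
          ord: float = 2, eps: float = 1e-8) -> torch.Tensor:
    """Return direction_update rescaled so its norm matches magnitude_update's norm."""
    M = torch.linalg.vector_norm(magnitude_update, ord)
    D = torch.linalg.vector_norm(direction_update, ord)
    if D == 0:
        D = M  # degenerate direction: fall back so the scaling factor becomes ~1
    return direction_update * (M / (D + eps))

# e.g. take the step size of an SGD update and the direction of an Adam update
sgd_update = torch.tensor([0.30, -0.40, 0.10])
adam_update = torch.tensor([1.0, 1.0, -1.0])
grafted = graft(sgd_update, adam_update)
print(grafted, torch.linalg.vector_norm(grafted))  # norm ≈ ||sgd_update|| ≈ 0.51
```

The layerwise variant applies the same rescaling per tensor instead of over the flattened parameter vector.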
torchzero/modules/meta/optimizer_wrapper.py (deleted)
@@ -1,173 +0,0 @@
- from collections.abc import Callable, Sequence
- from typing import Any, overload
-
- import torch
- from typing_extensions import Concatenate, ParamSpec
-
- from ...core import OptimizerModule
- from .return_overrides import SetGrad
-
- K = ParamSpec('K')
-
- class Wrap(OptimizerModule):
-     """
-     Wraps any torch.optim.Optimizer.
-
-     Sets .grad attribute to the current update and steps with the `optimizer`.
-
-     Additionally, if this is not the last module, this takes the update of `optimizer`,
-     undoes it and passes to the next module instead. That means you can chain multiple
-     optimizers together.
-
-     Args:
-         optimizer (torch.optim.Optimizer): optimizer to wrap,
-             or a callable (class) that constructs the optimizer.
-         kwargs:
-             if class is passed, kwargs are passed to the constructor.
-             parameters are passed separately and automatically
-             which is the point of passing a constructor
-             instead of an optimizer directly.
-
-     This can be constructed in two ways.
-     .. code-block:: python
-         wrapper = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr = 0.1))
-         # or
-         wrapper = OptimizerWrapper(torch.optim.SGD, lr = 0.1)
-     """
-
-     @overload
-     def __init__(self, optimizer: torch.optim.Optimizer): ...
-     @overload
-     # def __init__[**K](
-     def __init__(
-         self,
-         optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
-         *args: K.args,
-         **kwargs: K.kwargs,
-         # optimizer: abc.Callable[..., torch.optim.Optimizer],
-         # *args,
-         # **kwargs,
-     ): ...
-     def __init__(self, optimizer, *args, **kwargs):
-
-         super().__init__({})
-         self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
-         self._args = args
-         self._kwargs = kwargs
-
-     def _initialize_(self, params, set_passed_params):
-         """Initializes this optimizer and all children with the given parameters."""
-         super()._initialize_(params, set_passed_params=set_passed_params)
-         if isinstance(self._optimizer_cls, torch.optim.Optimizer) or not callable(self._optimizer_cls):
-             self.optimizer = self._optimizer_cls
-         else:
-             self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
-
-     @torch.no_grad
-     def step(self, vars):
-         # check attrs
-         # if self.pass_closure:
-         #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
-         #     if state.ascent is not None:
-         #         raise ValueError('pass_closure = True, means ascent must be None (not sure though)')
-
-         params = self.get_params()
-
-         if self.next_module is None:
-             # set grad to ascent and make a step with the optimizer
-             g = vars.maybe_use_grad_(params)
-             params.set_grad_(g)
-             vars.fx0 = self.optimizer.step()
-             return vars.get_loss()
-
-
-         params_before_step = params.clone()
-
-         g = vars.maybe_use_grad_(params)
-         params.set_grad_(g)
-         vars.fx0 = self.optimizer.step()
-
-         # calculate update as difference in params
-         vars.ascent = params_before_step - params
-         params.set_(params_before_step)
-         return self.next_module.step(vars)
-
-
- class WrapClosure(OptimizerModule):
-     """
-     Wraps any torch.optim.Optimizer. This only works with modules with :code:`target = "Closure"` argument.
-     The modified closure will be passed to the optimizer.
-
-     Alternative any module can be turned into a closure module by using :any:`MakeClosure` module,
-     in that case this should be placed after MakeClosure.
-
-     Args:
-         optimizer (torch.optim.Optimizer): optimizer to wrap,
-             or a callable (class) that constructs the optimizer.
-         kwargs:
-             if class is passed, kwargs are passed to the constructor.
-             parameters are passed separately and automatically
-             which is the point of passing a constructor
-             instead of an optimizer directly.
-
-     This can be constructed in two ways.
-
-     .. code-block:: python
-
-         wrapper = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr = 0.1))
-         # or
-         wrapper = OptimizerWrapper(torch.optim.SGD, lr = 0.1)
-
-     """
-
-     @overload
-     def __init__(self, optimizer: torch.optim.Optimizer,): ...
-     @overload
-     def __init__(
-         self,
-         optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
-         *args: K.args,
-         **kwargs: K.kwargs,
-         # optimizer: abc.Callable[..., torch.optim.Optimizer],
-         # *args,
-         # **kwargs,
-     ): ...
-     def __init__(self, optimizer, *args, **kwargs):
-
-         super().__init__({})
-         self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
-         self._args = args
-         self._kwargs = kwargs
-
-     def _initialize_(self, params, set_passed_params):
-         """Initializes this optimizer and all children with the given parameters."""
-         super()._initialize_(params, set_passed_params=set_passed_params)
-         if isinstance(self._optimizer_cls, torch.optim.Optimizer) or not callable(self._optimizer_cls):
-             self.optimizer = self._optimizer_cls
-         else:
-             self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
-
-     @torch.no_grad
-     def step(self, vars):
-         # check attrs
-         # if self.pass_closure:
-         #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
-         #     if state.ascent is not None:
-         #         raise ValueError('pass_closure = True, means ascent must be None (not sure though)')
-
-         params = self.get_params()
-
-         if self.next_module is None:
-             # set grad to ascent and make a step with the optimizer
-             vars.fx0 = self.optimizer.step(vars.closure) # type:ignore
-             return vars.get_loss()
-
-
-         params_before_step = params.clone()
-         vars.fx0 = self.optimizer.step(vars.closure) # type:ignore
-
-         # calculate update as difference in params
-         vars.ascent = params_before_step - params
-         params.set_(params_before_step)
-         return self.next_module.step(vars)
-
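Wrap's chaining behaviour rests on one trick spelled out in its docstring: step the wrapped optimizer, read off the update it produced as the difference between parameters before and after, then undo the step so later modules can decide how to apply that update. A standalone sketch of the trick with a stock Adam instance (the names and the half-strength re-application are illustrative, not torchzero code):

```python
import torch

# Step a stock optimizer, recover the update it produced, then undo it.
w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
inner = torch.optim.Adam([w], lr=0.1)

loss = (w ** 2).sum()
loss.backward()

before = w.detach().clone()
inner.step()                               # Adam moves the parameter
with torch.no_grad():
    update = before - w.detach()           # the update Adam just applied
    w.copy_(before)                        # undo it, leaving only `update`

print("recovered update:", update)         # could now be scaled, clipped, or chained
with torch.no_grad():
    w.sub_(0.5 * update)                   # e.g. apply it at half strength
```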
torchzero/modules/meta/return_overrides.py (deleted)
@@ -1,46 +0,0 @@
- import torch
- from ...tensorlist import TensorList
- from ...core import OptimizerModule, _get_loss, _ClosureType
-
- class SetGrad(OptimizerModule):
-     """Doesn't update parameters, instead replaces all parameters `.grad` attribute with the current update.
-     You can now step with any pytorch optimizer that utilises the `.grad` attribute."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def step(self, vars):
-         if self.next_module is not None: raise ValueError("SetGrad can't have children")
-         params = self.get_params()
-         g = vars.maybe_use_grad_(params) # this may execute the closure which might be modified
-         params.set_grad_(g)
-         return vars.get_loss()
-
-
- class ReturnAscent(OptimizerModule):
-     """Doesn't update parameters, instead returns the update as a TensorList of tensors."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def step(self, vars) -> TensorList: # type:ignore
-         if self.next_module is not None: raise ValueError("ReturnAscent can't have children")
-         params = self.get_params()
-         update = vars.maybe_use_grad_(params) # this will execute the closure which might be modified
-         return update
-
- class ReturnClosure(OptimizerModule):
-     """Doesn't update parameters, instead returns the current modified closure.
-     For example, if you put this after :code:`torchzero.modules.FDM(target = "closure")`,
-     the closure will set `.grad` attribute to gradients approximated via finite difference.
-     You can then pass that closure to something that requires closure like `torch.optim.LBFGS`."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def step(self, vars) -> _ClosureType: # type:ignore
-         if self.next_module is not None: raise ValueError("ReturnClosure can't have children")
-         if vars.closure is None:
-             raise ValueError("MakeClosure requires closure")
-         return vars.closure
-
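SetGrad's idea, stripped of the module machinery: compute whatever update you like, write it into .grad, and let any stock .grad-based optimizer apply it. A small self-contained sketch (the sign-of-gradient update is just an example of a custom update):

```python
import torch

# Write a custom update into .grad and let a stock optimizer consume it.
w = torch.nn.Parameter(torch.tensor([2.0, -3.0]))
opt = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()
loss.backward()

with torch.no_grad():
    custom_update = w.grad.sign()   # e.g. a sign-of-gradient update
    w.grad.copy_(custom_update)     # overwrite .grad with the custom update

opt.step()                          # SGD now applies w -= lr * custom_update
print(w)                            # tensor([ 1.9000, -2.9000], ...)
```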
torchzero/modules/misc/__init__.py (deleted)
@@ -1,10 +0,0 @@
- r"""
- This module includes various basic operators, notable LR for setting the learning rate,
- as well as gradient/update clipping and normalization.
- """
-
- from .basic import Clone, Fill, Grad, Identity, Lambda, Zeros, Alpha, GradToUpdate, MakeClosure
- from .lr import LR
- from .on_increase import NegateOnLossIncrease
- from .multistep import Multistep
- from .accumulate import Accumulate
torchzero/modules/misc/accumulate.py (deleted)
@@ -1,43 +0,0 @@
- from collections.abc import Callable, Iterable
-
- import torch
-
- from ...tensorlist import TensorList
-
- from ...core import OptimizerModule
-
-
- class Accumulate(OptimizerModule):
-     """Accumulates update over n steps, and steps once updates have been accumulated.
-     Put this as the first module to get gradient accumulation.
-
-     Args:
-         n_steps (int): number of steps (batches) to accumulate the update over.
-         mean (bool, optional):
-             If True, divides accumulated gradients by number of step,
-             since most loss functions calculate the mean of all samples
-             over batch. Defaults to True.
-     """
-     def __init__(self, n_steps: int, mean = True):
-
-         super().__init__({})
-         self.n_steps = n_steps
-         self.mean = mean
-         self.cur_step = 0
-
-     @torch.no_grad
-     def step(self, vars):
-         self.cur_step += 1
-
-         params = self.get_params()
-         accumulated_update = self.get_state_key('accumulated_grads')
-         accumulated_update += vars.maybe_use_grad_(params)
-
-         if self.cur_step % self.n_steps == 0:
-             vars.ascent = accumulated_update.clone()
-             if self.mean: vars.ascent /= self.n_steps
-             accumulated_update.zero_()
-             return self._update_params_or_step_with_next(vars)
-
-
-         return vars.get_loss()
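Accumulate implements plain gradient accumulation: sum the update over n_steps micro-batches, optionally divide by n_steps (mean=True), and only then let the rest of the chain step. The same pattern in vanilla PyTorch, for comparison (the model, data, and step counts are illustrative):

```python
import torch

# Sum gradients over n_steps micro-batches, average them, then take one optimizer step.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
n_steps = 4

accumulated = [torch.zeros_like(p) for p in model.parameters()]

for step in range(8):                       # 8 micro-batches -> 2 real steps
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    model.zero_grad()
    loss.backward()
    for buf, p in zip(accumulated, model.parameters()):
        buf += p.grad                       # accumulate this micro-batch's gradient

    if (step + 1) % n_steps == 0:
        for buf, p in zip(accumulated, model.parameters()):
            p.grad = buf / n_steps          # mean=True behaviour: divide by n_steps
            buf.zero_()
        opt.step()                          # one update per n_steps micro-batches
```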