torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/regularization/normalization.py (deleted in 0.3.2)
@@ -1,328 +0,0 @@
- from collections import abc
- import typing
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizerModule
-
-
- def _normalize_grad_(
-     grads: abc.Iterable[torch.Tensor],
-     norm_value: float = 1,
-     ord: float = 2,
-     min: float = 0,
-     mode: typing.Literal["global", "param", "channel"] = "param",
-     min_numel=2,
- ):
-     if mode in ('param', 'channel'):
-         for grad in grads:
-             if grad.numel() >= min_numel:
-                 if mode == 'channel' and grad.ndim >= 2:
-                     norm = torch.linalg.vector_norm(grad, ord, dim=tuple(range(1, grad.ndim)), keepdim=True) # pylint:disable=not-callable
-                     norm[norm<=min] = 1
-                     grad /= norm / norm_value
-                 else: # mode = 'param' or 1d grad
-                     norm = torch.linalg.vector_norm(grad, ord) # pylint:disable=not-callable
-                     if norm > min:
-                         grad /= norm / norm_value
-     else:
-         if not isinstance(grads, TensorList): grads = TensorList(grads)
-         norm = grads.total_vector_norm(ord)
-         if norm > min:
-             grads /= norm / norm_value # type:ignore
-
- @torch.no_grad
- def normalize_grad_(
-     params: abc.Iterable[torch.Tensor],
-     norm_value: float = 1,
-     ord: float = 2,
-     min: float = 0,
-     mode: typing.Literal["global", "param", "channel"] = "global",
-     min_numel=2,
- ):
-     """Normalizes gradients of an iterable of parameters.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to normalize.
-         norm_value (float, optional): value to normalize to. Defaults to 1.
-         ord (float, optional): order of the norm. Defaults to 2.
-         min (float, optional):
-             won't normalize when gradient is below this norm, you can increase this
-             to avoid amplifying extremely small gradients. Defaults to 0.
-         mode (str, optional):
-             what to normalize.
-
-             - "global": normalize the entire gradient, as if it was a single vector.
-
-             - "param": normalize each param's gradient (default).
-
-             - "channel": normalize gradient of each channel of each param.
-         min_numel (int, optional):
-             skips parameters with less than this many elements. This avoids the issue where
-             parameters that have a single element always get set to the value of 1.
-             Ignored when mode is 'global'.
-
-     Example:
-         >>> normalize_grad_(model.parameters())
-     """
-     _normalize_grad_(
-         (p.grad for p in params if p.grad is not None),
-         norm_value = norm_value,
-         ord = ord,
-         min = min,
-         mode = mode,
-         min_numel = min_numel,
-     )
-
- class Normalize(OptimizerModule):
-     """Normalizes update to the given norm value.
-
-     Args:
-         norm_value (float, optional): value to normalize to. Defaults to 1.
-         ord (float, optional): order of the norm. Defaults to 2.
-         min (float, optional):
-             won't normalize when gradient is below this norm, you can increase this
-             to avoid amplifying extremely small gradients. Defaults to 0.
-         mode (str, optional):
-             what to normalize.
-
-             - "global": normalize the entire gradient, as if it was a single vector.
-
-             - "param": normalize each param's gradient (default).
-
-             - "channel": normalize gradient of each channel of each param.
-         min_numel (int, optional):
-             skips parameters with less than this many elements. This avoids the issue where
-             parameters that have a single element always get set to the value of 1.
-             Ignored when mode is 'global'.
-     """
-     def __init__(
-         self,
-         norm_value: float = 1,
-         ord: float = 2,
-         min: float = 0,
-         mode: typing.Literal["global", "param", "channel"] = "param",
-         min_numel=2,
-     ):
-         super().__init__({})
-         self.norm_value = norm_value
-         self.ord = ord
-         self.min = min
-         self.mode: typing.Literal["global", "param", "channel"] = mode
-         self.min_numel = min_numel
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         _normalize_grad_(
-             ascent,
-             norm_value = self.norm_value,
-             ord = self.ord,
-             min = self.min,
-             mode = self.mode,
-             min_numel = self.min_numel,
-         )
-         return ascent
-
-
- def _centralize_grad_(
-     grads: abc.Iterable[torch.Tensor],
-     mode: typing.Literal["global", "param", "channel"] = "channel",
-     min_ndim=2,
-     min_numel=2,
- ):
-     if mode in ('param', 'channel'):
-         if mode == 'channel': min_ndim = max(min_ndim, 2)
-         for grad in grads:
-             if grad.numel() >= min_numel and grad.ndim > min_ndim:
-                 if mode == 'channel':
-                     grad -= grad.mean(dim=tuple(range(1, grad.ndim)), keepdim=True)
-                 else: # mode = 'param'
-                     grad -= grad.mean()
-     else:
-         if not isinstance(grads, TensorList): grads = TensorList(grads)
-         grads -= grads.mean()
-
- @torch.no_grad
- def centralize_grad_(
-     params: abc.Iterable[torch.Tensor],
-     mode: typing.Literal["global", "param", "channel"] = "channel",
-     min_ndim=2,
-     min_numel=2,
- ):
-     """Centralizes gradients of an iterable of parameters.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to centralize.
-         mode (str, optional):
-             what to centralize.
-
-             - "global": centralize the entire gradient (uses mean of entire gradient).
-
-             - "param": centralize each param's gradient.
-
-             - "channel": centralize gradient of each channel of each param (default).
-         min_numel (int, optional):
-             skips parameters with less than this many elements. This avoids negating updates for
-             parameters that have a single element since subtracting mean always makes it 0.
-             Ignored when mode is 'global'.
-         min_ndim (int, optional):
-             skips parameters with less than this many dimensions.
-             bias usually has 1 dimension and you don't want to centralize it.
-             Ignored when mode is 'global'.
-
-     reference
-         *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
-         Gradient centralization: A new optimization technique for deep neural networks.
-         In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
-         August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*
-
-     Example:
-         >>> centralize_grad_(model.parameters())
-     """
-     _centralize_grad_(
-         (p.grad for p in params if p.grad is not None),
-         mode = mode,
-         min_ndim = min_ndim,
-         min_numel = min_numel,
-     )
-
- class Centralize(OptimizerModule):
-     """Centralizes the update.
-
-     Args:
-         mode (str, optional):
-             what to centralize.
-
-             - "global": centralize the entire gradient (uses mean of entire gradient).
-
-             - "param": centralize each param's gradient.
-
-             - "channel": centralize gradient of each channel of each param (default).
-         min_numel (int, optional):
-             skips parameters with less than this many elements. This avoids negating updates for
-             parameters that have a single element since subtracting mean always makes it 0.
-             Ignored when mode is 'global'.
-         min_ndim (int, optional):
-             skips parameters with less than this many dimensions.
-             bias usually has 1 dimension and you don't want to centralize it.
-             Ignored when mode is 'global'.
-
-     reference
-         *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
-         Gradient centralization: A new optimization technique for deep neural networks.
-         In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
-         August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*
-     """
-     def __init__(
-         self,
-         mode: typing.Literal["global", "param", "channel"] = "channel",
-         min_ndim=2,
-         min_numel=2,
-     ):
-         super().__init__({})
-         self.mode: typing.Literal["global", "param", "channel"] = mode
-         self.min_ndim = min_ndim
-         self.min_numel = min_numel
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         _centralize_grad_(
-             ascent,
-             mode = self.mode,
-             min_ndim = self.min_ndim,
-             min_numel = self.min_numel,
-         )
-         return ascent
-
-
- def clip_grad_value_(params: abc.Iterable[torch.Tensor], value:float):
-     """Clip the gradients of an iterable of parameters at specified value.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor that will have gradients clipped.
-         value (float, optional):
-             maximum allowed magnitude of the gradients.
-             The gradients are clipped in the range `[-clip_value, clip_value]`
-     """
-     TensorList(params).get_existing_grads().clamp_(-value, value)
-
- class ClipValue(OptimizerModule):
-     """Clip the update at specified value.
-
-     Args:
-         value (float, optional): maximum allowed magnitude of the gradients.
-             The gradients are clipped in the range `[-clip_value, clip_value]`
-     """
-     def __init__(self, value: float):
-         defaults = dict(value = value)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         value = self.get_group_key('value')
-         ascent.clamp_(-value, value)
-         return ascent
-
- def clip_grad_norm_(
-     params: abc.Iterable[torch.Tensor],
-     max_norm: float,
-     ord: float = 2,
-     mode: typing.Literal["global", "param", "channel"] = "param",
- ):
-     """Clip the gradient norm of an iterable of parameters.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to clip the norm of.
-         max_norm (float, optional): norm value to clip to.
-         ord (float, optional): order of the norm. Defaults to 2.
-         mode (str, optional):
-             what to calculate the norm over.
-
-             - "global": calculates and clips the norm of the entire gradient, as if it was a single vector.
-
-             - "param": calculates and clips each param's gradient norm (default).
-
-             - "channel": calculate and clip the norm of gradient of each channel of each param.
-
-     Example:
-         >>> clip_grad_norm_(model.parameters())
-     """
-     _normalize_grad_(
-         (p.grad for p in params if p.grad is not None),
-         norm_value = max_norm,
-         min = max_norm,
-         ord = ord,
-         mode = mode,
-     )
-
- class ClipNorm(OptimizerModule):
-     """Clip the gradient norm of an iterable of parameters.
-
-     Args:
-         max_norm (float, optional): norm value to clip to.
-         ord (float, optional): order of the norm. Defaults to 2.
-         mode (str, optional):
-             what to calculate the norm over.
-
-             - "global": calculates and clips the norm of the entire gradient, as if it was a single vector.
-
-             - "param": calculates and clips each param's gradient norm (default).
-
-             - "channel": calculate and clip the norm of gradient of each channel of each param.
-     """
-     def __init__(self, max_norm: float, ord:float=2, mode: typing.Literal["global", "param", "channel"] = "param",):
-         super().__init__({})
-         self.max_norm = max_norm
-         self.ord = ord
-         self.mode: typing.Literal["global", "param", "channel"] = mode
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         _normalize_grad_(
-             ascent,
-             norm_value = self.max_norm,
-             min = self.max_norm,
-             ord = self.ord,
-             mode = self.mode,
-         )
-         return ascent
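
Note: the removed normalization and clipping helpers above all reduce to one operation: compute a norm per parameter (or per channel, or globally) and divide the gradient by norm / norm_value when the norm exceeds a threshold. For readers who only want that core idea, here is a minimal standalone sketch in plain PyTorch, assuming an ordinary model with populated gradients; it is not the torchzero 0.3.2 API (judging by the file list, the replacements now live under torchzero/modules/clipping/).

    import torch

    @torch.no_grad()
    def normalize_param_grads(params, norm_value=1.0, ord=2, min_norm=0.0):
        # Per-parameter gradient normalization, mirroring the removed
        # _normalize_grad_ in mode='param': rescale each gradient to
        # `norm_value` unless its norm is at or below `min_norm`.
        for p in params:
            if p.grad is None:
                continue
            norm = torch.linalg.vector_norm(p.grad, ord)
            if norm > min_norm:
                p.grad /= norm / norm_value

    # usage (hypothetical training loop):
    # loss.backward()
    # normalize_param_grads(model.parameters(), norm_value=1.0)
    # optimizer.step()

Setting min_norm equal to norm_value turns this into norm clipping, which is exactly how the removed clip_grad_norm_ reuses _normalize_grad_.
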
torchzero/modules/regularization/ortho_grad.py (deleted in 0.3.2)
@@ -1,78 +0,0 @@
- """
- ⟂Grad (read “ortho-grad”) was proposed in https://arxiv.org/abs/2501.04697.
-
- """
- from collections.abc import Iterable
-
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizerModule, _Targets
-
-
- def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
-     """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
-         eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
-
-     reference
-         https://arxiv.org/abs/2501.04697
-     """
-     if not isinstance(params, TensorList): params = TensorList(params)
-     params = params.with_grad()
-     grad = params.grad
-     grad -= (((params*grad).total_sum())/(params*params).total_sum() + eps) * params
-
- class OrthoGrad(OptimizerModule):
-     """⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
-         eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
-         renormalize (bool, optional): whether to renormalize gradients back to original norm (default: True).
-         sqrt_scale (bool, optional):
-             uses square root of the scale to make it more impactful, experimental setting and doesn't really work (default: False).
-         add (bool, optional):
-             Experimental option that changes subtraction to addition.
-             I don't think it has any geometric meaning but it drives weights towards zero instead of away from it.
-             and it seems to work well with sqrt_scale = True. It speeds up convergence by a lot compared to using vanilla gradient,
-             but also has INSANE overfitting.
-         target (str, optional):
-             determines what this module updates.
-
-             "ascent" - it updates the ascent (default).
-
-             "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-             "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-
-     reference
-         https://arxiv.org/abs/2501.04697
-     """
-     def __init__(self, eps: float = 1e-30, renormalize=True, sqrt_scale = False, add=False, target: _Targets = 'ascent'):
-         super().__init__({}, target=target)
-         self.eps = eps
-         self.add = add
-         self.renormalize = renormalize
-         self.sqrt_scale = sqrt_scale
-
-     def _update(self, vars, ascent):
-         params = self.get_params()
-
-         if self.renormalize: orig_norm = ascent.norm(2) + self.eps
-         else: orig_norm = 1
-
-         scale = (params*ascent).total_sum() / ((params*params).total_sum() + self.eps)
-         if self.sqrt_scale:
-             scale = scale.abs().sqrt() * scale.sign()
-
-         if self.add: ascent += params * scale
-         else: ascent -= params * scale
-
-         if self.renormalize:
-             ascent *= (orig_norm / ascent.norm(2))
-
-         return ascent
-
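
Note: the projection the removed orthograd_ performs is the vector rejection of the gradient from the weights, g ← g − (⟨w, g⟩ / (⟨w, w⟩ + eps)) · w, with the inner products taken over all parameters at once. Below is a minimal standalone sketch of that step in plain PyTorch, with eps placed in the denominator as the docstring describes; the helper name is ours and this is not the 0.3.2 module (torchzero/modules/optimizers/orthograd.py).

    import torch

    @torch.no_grad()
    def orthograd_sketch(params, eps=1e-30):
        # Make the global gradient orthogonal to the global weight vector:
        # g <- g - (<w, g> / (<w, w> + eps)) * w   (arXiv:2501.04697)
        params = [p for p in params if p.grad is not None]
        wg = sum(torch.sum(p * p.grad) for p in params)  # <w, g>
        ww = sum(torch.sum(p * p) for p in params)       # <w, w>
        scale = wg / (ww + eps)
        for p in params:
            p.grad -= scale * p
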
torchzero/modules/regularization/weight_decay.py (deleted in 0.3.2)
@@ -1,92 +0,0 @@
- from typing import Literal
- from collections.abc import Iterable
-
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizerModule, _Targets
-
-
- def l2_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
-     """Adds L2 weight regularization term to the gradients in-place.
-
-     Args:
-         params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
-         alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
-     """
-     p = TensorList(params).with_requires_grad()
-     p.ensure_grad_()
-     p.grad.add_(p, alpha = alpha)
-
- def l1_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
-     """Adds L1 weight regularization term to the gradients in-place.
-
-     Args:
-         params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
-         alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
-     """
-     p = TensorList(params).with_requires_grad()
-     p.ensure_grad_()
-     p.grad.add_(p.sign(), alpha = alpha)
-
- def weight_decay_penalty(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord:float = 2):
-     """Calculate the weight decay penalty term that can be added to the loss.
-
-     Args:
-         params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
-         alpha (float): multiplier to the regularizer.
-         ord (int, optional): order of the norm. Defaults to 2.
-     """
-     return TensorList(params).norm(ord) * alpha
-
- def decay_weights_(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord:Literal[1, 2] = 2):
-     """Apply weight decay directly to parameters in-place.
-
-     Args:
-         params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor to decay.
-         alpha (float): by how much to decay parameters (default: 1e-2)
-         ord (float, optional):
-             order of the penalty, 1 and 2 are currently supported (L1 and L2 regularization) (default: 2)
-     """
-     params = TensorList(params)
-     if ord == 2: params.mul_(1-alpha)
-     elif ord == 1: params.sub_(params.sign().mul_(alpha))
-     else: raise NotImplementedError(f'order {ord} is not supported')
-
-
- class WeightDecay(OptimizerModule):
-     """Adds weight decay term (L1 or L2 regularization) to the ascent direction.
-
-     Put this at the end to make it decoupled.
-
-     Args:
-         alpha (float, optional): multiplier to the regularizer (default: 1e-2)
-         ord (Literal[1, 2], optional):
-             order of the penalty, 1 and 2 are currently supported (L1 and L2 regularization).
-             Defaults to 2.
-         target (str, optional):
-             determines what this module updates.
-
-             "ascent" - it updates the ascent
-
-             "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-             "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-     """
-     def __init__(self, alpha: float = 1e-2, ord:Literal[1, 2] = 2, target: _Targets = "ascent"):
-         defaults = dict(alpha = alpha)
-         super().__init__(defaults, target = target)
-         self.ord = ord
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-         alpha = self.get_group_key('alpha')
-
-         if any(i != 0 for i in alpha):
-
-             if self.ord == 1: ascent.add_(params.sign() * alpha)
-             elif self.ord == 2: ascent.add_(params * alpha)
-             else: raise NotImplementedError(f'weight descent of order {self.ord} not implemented.')
-
-         return ascent
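
Note: the removed WeightDecay module simply adds the decay term to the update, alpha * w for L2 or alpha * sign(w) for L1; running it after the other modules is what makes the decay decoupled, since the term then bypasses any preconditioning or momentum applied earlier in the chain. A minimal standalone sketch of that step over plain lists of tensors (hypothetical helper, not the 0.3.2 module in torchzero/modules/weight_decay/weight_decay.py):

    import torch

    @torch.no_grad()
    def add_weight_decay_(update, params, alpha=1e-2, ord=2):
        # Add the weight-decay term to an existing update in-place,
        # as the removed WeightDecay._update does.
        for u, p in zip(update, params):
            if ord == 2:
                u.add_(p, alpha=alpha)         # L2: u += alpha * w
            elif ord == 1:
                u.add_(p.sign(), alpha=alpha)  # L1: u += alpha * sign(w)
            else:
                raise NotImplementedError(f"order {ord} is not supported")
        return update
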
torchzero/modules/scheduling/__init__.py (deleted in 0.3.2)
@@ -1,2 +0,0 @@
- from .lr_schedulers import LRWarmup
- from .step_size import PolyakStepSize, RandomStepSize
torchzero/modules/scheduling/lr_schedulers.py (deleted in 0.3.2)
@@ -1,131 +0,0 @@
- from collections.abc import Callable
- from functools import partial
- from typing import Any, overload, TYPE_CHECKING
- import random
-
- import torch
- from ...core import OptimizerModule
-
-
- if TYPE_CHECKING:
-     from ...optim import Modular
-
-
- # LR SCHEDULING MOVED TO LR MODULE
-
- # def _set_momentum_hook(optimizer, state, momentum):
- #     for module in optimizer.unrolled_modules:
- #         if 'momentum' in module.defaults:
- #             for g in module.param_groups:
- #                 g['momentum'] = momentum
- #         if 'beta1' in module.defaults:
- #             for g in module.param_groups:
- #                 g['beta1'] = momentum
-
- # def _add_scheduler_hook(opt: "Modular", scheduler_cls, id):
- #     """post-init hook that sets `scheduler_step_fn` to the scheduler step."""
- #     # get LR module
- #     lr_module = opt.get_lr_module()
-
- #     # get current LRScheduler module
- #     scheds = [i for i in opt.unrolled_modules if isinstance(i, LRScheduler)]
- #     scheds = [i for i in scheds if i.id == id]
- #     if len(scheds) != 1:
- #         raise RuntimeError(f"more than 1 module with id {id}: {scheds}")
-
- #     sch_module = scheds[0]
-
- #     # make a scheduler and save the step function
- #     scheduler = scheduler_cls(lr_module)
- #     sch_module.scheduler_step_fn = scheduler.step
-
-
- # class LRScheduler(OptimizerModule):
- #     """Use any pytorch lr scheduler.
-
- #     Important - the lr is applied multiplicatively and multiplies with learning rate of other modules,
- #     so usually base learning rate of the lr scheduler, such as `max_lr` for OneCycleLR, should be set to 1.
-
- #     Args:
- #         lr_scheduler (Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any]):
- #             something like:
- #             .. code:: py
- #                 lambda opt: OneCycleLR(opt, max_lr = 1, total_steps = 60000)
- #         update_every (int, optional):
- #             call `step` every n steps, useful for schedulers that only step once per epoch. Defaults to 1.
- #         cycle_momentum (bool, optional):
- #             enables support for cycling momentum with schedulers that support it, such as `OneCycleLR`.
- #             Unlike lr, momentum is not applied multiplicatively, but set to all other modules with
- #             `momentum` or `beta` settings. Has no effect if there are no modules that support momentum. Defaults to False.
- #         init_lr (float, optional):
- #             initial lr, I believe most lr schedulers ignore this. Defaults to 1.
- #         init_momentum (float, optional):
- #             initial init_momentum, I believe most lr schedulers ignore this.
- #             Has no effect if `cycle_momentum` is False or there are no modules that support momentum. Defaults to 0.
- #     """
- #     def __init__(
- #         self,
- #         lr_scheduler: Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any],
- #         step_every: int = 1,
- #         cycle_momentum: bool = True,
- #     ):
- #         super().__init__({})
- #         scheduler = lr_scheduler(self.dummy_opt)
- #         self.update_every = step_every
- #         self.cycle_momentum = cycle_momentum
-
- #         self.scheduler_step_fn = scheduler.step
- #         self.cur = 0
- #         self.cur_lr = init_lr
- #         self.cur_momentum = init_momentum
-
- #         self.id = random.random()
-
- #     def step(self, vars):
- #         if self.cur % self.update_every == 0:
- #             self.scheduler_step_fn()
- #             self.cur_lr = self.dummy_opt.first_param_group['lr']
- #             self.cur_momentum = self.dummy_opt.first_param_group['momentum']
-
- #         params = self.get_params()
- #         ascent = state.maybe_use_grad_(params)
- #         ascent *= self.cur_lr
-
- #         if self.cycle_momentum:
- #             state.add_post_step_hook(partial(_set_momentum_hook, momentum = self.cur_momentum))
-
- class LRWarmup(OptimizerModule):
-     """Linear learning rate warmup.
-
-     Args:
-         n_steps (int): number of warmup steps.
-         start_lr (float, optional): initial lr. Defaults to 1e-8.
-         end_lr (float, optional): final lr. Defaults to 1.
-         delay_steps (int, optional): number of `start_lr` steps before starting the warmup. Defaults to 0.
-     """
-     def __init__(self, n_steps: int, start_lr: float = 1e-8, end_lr: float = 1, delay_steps: int = 0):
-
-         super().__init__({})
-         self.n_steps = n_steps
-         self.start_lr = start_lr
-         self.end_lr = end_lr
-         self.delay_steps = delay_steps
-
-         self.cur = 0
-
-     def _update(self, vars, ascent):
-         if self.cur < self.delay_steps:
-             if self.start_lr != 1: ascent *= self.start_lr
-
-         elif self.cur >= self.n_steps + self.delay_steps:
-             if self.end_lr != 1: ascent *= self.end_lr
-
-         else:
-             remaining = (self.n_steps - (self.cur-self.delay_steps)) / self.n_steps
-             lr = (self.start_lr * remaining) + self.end_lr * (1 - remaining)
-             ascent *= lr
-
-         self.cur += 1
-         return ascent
-
-
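
Note: LRWarmup is a plain linear interpolation of a multiplicative factor applied to the update, from start_lr to end_lr over n_steps, optionally after delay_steps of start_lr (the comment at the top of the removed file notes that LR scheduling moved to the new lr module). The arithmetic of the removed _update, stripped of the OptimizerModule plumbing, is sketched below; the helper name is ours, not part of torchzero.

    def warmup_factor(step, n_steps, start_lr=1e-8, end_lr=1.0, delay_steps=0):
        # Multiplicative factor the removed LRWarmup applies at a given step.
        if step < delay_steps:
            return start_lr
        if step >= n_steps + delay_steps:
            return end_lr
        remaining = (n_steps - (step - delay_steps)) / n_steps
        return start_lr * remaining + end_lr * (1 - remaining)

    # e.g. n_steps=100: step 0 -> ~1e-8, step 50 -> ~0.5, step 100 and later -> 1.0
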