torchzero-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/regularization/noise.py
@@ -0,0 +1,77 @@
+from collections import abc
+from typing import Literal
+
+import torch
+
+from ...core import OptimizerModule
+from ...tensorlist import Distributions, TensorList, _Scalar, _ScalarSequence
+
+
+def add_noise_(
+    grads: abc.Iterable[torch.Tensor],
+    alpha: "_Scalar | _ScalarSequence" = 1e-2,
+    distribution: Distributions = "normal",
+    mode: Literal["absolute", "global", "param", "channel"] = "param",
+):
+    if not isinstance(grads, TensorList): grads = TensorList(grads)
+    if mode == 'absolute':
+        grads += grads.sample_like(alpha, distribution)
+
+    elif mode == 'global':
+        grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution)
+
+    elif mode == 'param':
+        grads += grads.sample_like(grads.abs().mean()*alpha, distribution)
+
+    elif mode == 'channel':
+        grads = grads.unbind_channels()
+        grads += grads.sample_like(grads.abs().mean()*alpha, distribution)
+
+class AddNoise(OptimizerModule):
+    """Adds noise to the update. By default the noise magnitude is relative to the mean absolute value of each parameter's update.
+
+    Args:
+        alpha (float, optional): magnitude of the noise. Defaults to 1.
+        distribution (Distributions, optional): distribution of the noise. Defaults to 'normal'.
+        mode (str, optional):
+            how the noise magnitude is calculated.
+
+            - "absolute": ignores the update magnitude and always uses `alpha` as the magnitude.
+
+            - "global": multiplies `alpha` by the mean absolute value of the entire update, as if it were a single vector.
+
+            - "param": multiplies `alpha` by the mean absolute value of each individual parameter's update (default).
+
+            - "channel": multiplies `alpha` by the mean absolute value of each channel of each parameter's update.
+    """
+
+    def __init__(
+        self,
+        alpha: float = 1.,
+        distribution: Distributions = "normal",
+        mode: Literal["absolute", "global", "param", "channel"] = "param",
+    ):
+        defaults = dict(alpha = alpha)
+        super().__init__(defaults)
+        self.distribution: Distributions = distribution
+        self.mode: Literal["absolute", "global", "param", "channel"] = mode
+
+    @torch.no_grad
+    def _update(self, state, ascent):
+        alpha = self.get_group_key('alpha')
+
+        add_noise_(ascent, alpha, self.distribution, self.mode)
+        return ascent
+
+class Random(OptimizerModule):
+    """Uses a random vector as the update. The vector is completely random and is not checked to be a descent direction,
+    so this is mainly useful in combination with other modules like Sum, Multiply, etc."""
+    def __init__(self, alpha: float = 1, distribution: Distributions = "normal"):
+        defaults = dict(alpha = alpha)
+        super().__init__(defaults)
+        self.distribution: Distributions = distribution
+
+    @torch.no_grad
+    def _update(self, state, ascent):
+        alpha = self.get_group_key('alpha')
+        return ascent.sample_like(alpha, self.distribution)
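
As a rough usage sketch (not taken from the package's own documentation), the functional `add_noise_` above can be applied to existing gradients between `backward()` and the optimizer step. The model, data, and hyperparameters below are placeholders, and the import uses the module path shown in this diff.

    import torch
    from torch import nn
    from torchzero.modules.regularization.noise import add_noise_

    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()

    # perturb the gradients in-place; with mode="param" the noise scale is
    # alpha times the mean absolute value of each parameter's gradient
    add_noise_([p.grad for p in model.parameters() if p.grad is not None], alpha=1e-2, mode="param")
    opt.step()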
torchzero/modules/regularization/normalization.py
@@ -0,0 +1,328 @@
+from collections import abc
+import typing
+import torch
+
+from ...tensorlist import TensorList
+from ...core import OptimizerModule
+
+
+def _normalize_grad_(
+    grads: abc.Iterable[torch.Tensor],
+    norm_value: float = 1,
+    ord: float = 2,
+    min: float = 0,
+    mode: typing.Literal["global", "param", "channel"] = "param",
+    min_numel=2,
+):
+    if mode in ('param', 'channel'):
+        for grad in grads:
+            if grad.numel() >= min_numel:
+                if mode == 'channel' and grad.ndim >= 2:
+                    norm = torch.linalg.vector_norm(grad, ord, dim=tuple(range(1, grad.ndim)), keepdim=True) # pylint:disable=not-callable
+                    norm[norm <= min] = norm_value # leaves channels whose norm is at or below `min` unchanged
+                    grad /= norm / norm_value
+                else: # mode = 'param' or 1d grad
+                    norm = torch.linalg.vector_norm(grad, ord) # pylint:disable=not-callable
+                    if norm > min:
+                        grad /= norm / norm_value
+    else:
+        if not isinstance(grads, TensorList): grads = TensorList(grads)
+        norm = grads.total_vector_norm(ord)
+        if norm > min:
+            grads /= norm / norm_value
+
+@torch.no_grad
+def normalize_grad_(
+    params: abc.Iterable[torch.Tensor],
+    norm_value: float = 1,
+    ord: float = 2,
+    min: float = 0,
+    mode: typing.Literal["global", "param", "channel"] = "global",
+    min_numel=2,
+):
+    """Normalizes gradients of an iterable of parameters.
+
+    Args:
+        params (abc.Iterable[torch.Tensor]): parameters that hold the gradients to normalize.
+        norm_value (float, optional): value to normalize to. Defaults to 1.
+        ord (float, optional): order of the norm. Defaults to 2.
+        min (float, optional):
+            gradients with a norm at or below this value are not normalized; increase this
+            to avoid amplifying extremely small gradients. Defaults to 0.
+        mode (str, optional):
+            what to normalize.
+
+            - "global": normalize the entire gradient, as if it were a single vector.
+
+            - "param": normalize each parameter's gradient (default).
+
+            - "channel": normalize the gradient of each channel of each parameter.
+        min_numel (int, optional):
+            skips parameters with fewer than this many elements. This avoids the issue where
+            parameters with a single element always get rescaled to `norm_value`.
+            Ignored when mode is 'global'.
+
+    Example:
+        >>> normalize_grad_(model.parameters())
+    """
+    _normalize_grad_(
+        (p.grad for p in params if p.grad is not None),
+        norm_value = norm_value,
+        ord = ord,
+        min = min,
+        mode = mode,
+        min_numel = min_numel,
+    )
+
+class Normalize(OptimizerModule):
+    """Normalizes the update to the given norm value.
+
+    Args:
+        norm_value (float, optional): value to normalize to. Defaults to 1.
+        ord (float, optional): order of the norm. Defaults to 2.
+        min (float, optional):
+            updates with a norm at or below this value are not normalized; increase this
+            to avoid amplifying extremely small updates. Defaults to 0.
+        mode (str, optional):
+            what to normalize.
+
+            - "global": normalize the entire update, as if it were a single vector.
+
+            - "param": normalize each parameter's update (default).
+
+            - "channel": normalize the update of each channel of each parameter.
+        min_numel (int, optional):
+            skips parameters with fewer than this many elements. This avoids the issue where
+            parameters with a single element always get rescaled to `norm_value`.
+            Ignored when mode is 'global'.
+    """
+    def __init__(
+        self,
+        norm_value: float = 1,
+        ord: float = 2,
+        min: float = 0,
+        mode: typing.Literal["global", "param", "channel"] = "param",
+        min_numel=2,
+    ):
+        super().__init__({})
+        self.norm_value = norm_value
+        self.ord = ord
+        self.min = min
+        self.mode: typing.Literal["global", "param", "channel"] = mode
+        self.min_numel = min_numel
+
+    @torch.no_grad
+    def _update(self, state, ascent):
+        _normalize_grad_(
+            ascent,
+            norm_value = self.norm_value,
+            ord = self.ord,
+            min = self.min,
+            mode = self.mode,
+            min_numel = self.min_numel,
+        )
+        return ascent
+
+
+def _centralize_grad_(
+    grads: abc.Iterable[torch.Tensor],
+    mode: typing.Literal["global", "param", "channel"] = "channel",
+    min_ndim=2,
+    min_numel=2,
+):
+    if mode in ('param', 'channel'):
+        if mode == 'channel': min_ndim = max(min_ndim, 2)
+        for grad in grads:
+            if grad.numel() >= min_numel and grad.ndim >= min_ndim:
+                if mode == 'channel':
+                    grad -= grad.mean(dim=tuple(range(1, grad.ndim)), keepdim=True)
+                else: # mode = 'param'
+                    grad -= grad.mean()
+    else:
+        if not isinstance(grads, TensorList): grads = TensorList(grads)
+        grads -= grads.mean()
+
+@torch.no_grad
+def centralize_grad_(
+    params: abc.Iterable[torch.Tensor],
+    mode: typing.Literal["global", "param", "channel"] = "channel",
+    min_ndim=2,
+    min_numel=2,
+):
+    """Centralizes gradients of an iterable of parameters.
+
+    Args:
+        params (abc.Iterable[torch.Tensor]): parameters that hold the gradients to centralize.
+        mode (str, optional):
+            what to centralize.
+
+            - "global": centralize the entire gradient (uses the mean of the entire gradient).
+
+            - "param": centralize each parameter's gradient.
+
+            - "channel": centralize the gradient of each channel of each parameter (default).
+        min_numel (int, optional):
+            skips parameters with fewer than this many elements. This avoids zeroing the update of
+            single-element parameters, since subtracting the mean always makes them 0.
+            Ignored when mode is 'global'.
+        min_ndim (int, optional):
+            skips parameters with fewer than this many dimensions.
+            Biases usually have 1 dimension and you don't want to centralize them.
+            Ignored when mode is 'global'.
+
+    Reference:
+        *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
+        Gradient centralization: A new optimization technique for deep neural networks.
+        In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
+        August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*
+
+    Example:
+        >>> centralize_grad_(model.parameters())
+    """
+    _centralize_grad_(
+        (p.grad for p in params if p.grad is not None),
+        mode = mode,
+        min_ndim = min_ndim,
+        min_numel = min_numel,
+    )
+
+class Centralize(OptimizerModule):
+    """Centralizes the update.
+
+    Args:
+        mode (str, optional):
+            what to centralize.
+
+            - "global": centralize the entire update (uses the mean of the entire update).
+
+            - "param": centralize each parameter's update.
+
+            - "channel": centralize the update of each channel of each parameter (default).
+        min_numel (int, optional):
+            skips parameters with fewer than this many elements. This avoids zeroing the update of
+            single-element parameters, since subtracting the mean always makes them 0.
+            Ignored when mode is 'global'.
+        min_ndim (int, optional):
+            skips parameters with fewer than this many dimensions.
+            Biases usually have 1 dimension and you don't want to centralize them.
+            Ignored when mode is 'global'.
+
+    Reference:
+        *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
+        Gradient centralization: A new optimization technique for deep neural networks.
+        In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
+        August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*
+    """
+    def __init__(
+        self,
+        mode: typing.Literal["global", "param", "channel"] = "channel",
+        min_ndim=2,
+        min_numel=2,
+    ):
+        super().__init__({})
+        self.mode: typing.Literal["global", "param", "channel"] = mode
+        self.min_ndim = min_ndim
+        self.min_numel = min_numel
+
+    @torch.no_grad
+    def _update(self, state, ascent):
+        _centralize_grad_(
+            ascent,
+            mode = self.mode,
+            min_ndim = self.min_ndim,
+            min_numel = self.min_numel,
+        )
+        return ascent
+
+
+def clip_grad_value_(params: abc.Iterable[torch.Tensor], value: float):
+    """Clips the gradients of an iterable of parameters at the specified value.
+
+    Args:
+        params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor that will have gradients clipped.
+        value (float):
+            maximum allowed magnitude of the gradients.
+            The gradients are clipped to the range `[-value, value]`.
+    """
+    TensorList(params).get_existing_grads().clamp_(-value, value)
+
+class ClipValue(OptimizerModule):
+    """Clips the update at the specified value.
+
+    Args:
+        value (float): maximum allowed magnitude of the update.
+            The update is clipped to the range `[-value, value]`.
+    """
+    def __init__(self, value: float):
+        defaults = dict(value = value)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def _update(self, state, ascent):
+        value = self.get_group_key('value')
+        ascent.clamp_(-value, value)
+        return ascent
+
+def clip_grad_norm_(
+    params: abc.Iterable[torch.Tensor],
+    max_norm: float,
+    ord: float = 2,
+    mode: typing.Literal["global", "param", "channel"] = "param",
+):
+    """Clips the gradient norm of an iterable of parameters.
+
+    Args:
+        params (abc.Iterable[torch.Tensor]): parameters that hold the gradients to clip the norm of.
+        max_norm (float): norm value to clip to.
+        ord (float, optional): order of the norm. Defaults to 2.
+        mode (str, optional):
+            what to calculate the norm over.
+
+            - "global": calculates and clips the norm of the entire gradient, as if it were a single vector.
+
+            - "param": calculates and clips each parameter's gradient norm (default).
+
+            - "channel": calculates and clips the gradient norm of each channel of each parameter.
+
+    Example:
+        >>> clip_grad_norm_(model.parameters(), max_norm=1.0)
+    """
+    _normalize_grad_(
+        (p.grad for p in params if p.grad is not None),
+        norm_value = max_norm,
+        min = max_norm,
+        ord = ord,
+        mode = mode,
+    )
+
+class ClipNorm(OptimizerModule):
+    """Clips the norm of the update.
+
+    Args:
+        max_norm (float): norm value to clip to.
+        ord (float, optional): order of the norm. Defaults to 2.
+        mode (str, optional):
+            what to calculate the norm over.
+
+            - "global": calculates and clips the norm of the entire update, as if it were a single vector.
+
+            - "param": calculates and clips each parameter's update norm (default).
+
+            - "channel": calculates and clips the update norm of each channel of each parameter.
+    """
+    def __init__(self, max_norm: float, ord: float = 2, mode: typing.Literal["global", "param", "channel"] = "param"):
+        super().__init__({})
+        self.max_norm = max_norm
+        self.ord = ord
+        self.mode: typing.Literal["global", "param", "channel"] = mode
+
+    @torch.no_grad
+    def _update(self, state, ascent):
+        _normalize_grad_(
+            ascent,
+            norm_value = self.max_norm,
+            min = self.max_norm,
+            ord = self.ord,
+            mode = self.mode,
+        )
+        return ascent
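
For reference, here is a hedged usage sketch of the functional helpers defined above (`centralize_grad_` and `clip_grad_norm_`); the model and values are placeholders, and the imports use the module path shown in this diff.

    import torch
    from torch import nn
    from torchzero.modules.regularization.normalization import centralize_grad_, clip_grad_norm_

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    x, y = torch.randn(16, 4), torch.randn(16, 2)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()

    centralize_grad_(model.parameters())               # subtract the per-channel gradient mean (default mode="channel")
    clip_grad_norm_(model.parameters(), max_norm=1.0)  # rescale any per-parameter gradient whose norm exceeds 1
    opt.step()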
torchzero/modules/regularization/ortho_grad.py
@@ -0,0 +1,78 @@
+"""
+⟂Grad (read “ortho-grad”) was proposed in https://arxiv.org/abs/2501.04697.
+
+"""
+from collections.abc import Iterable
+
+import torch
+
+from ...tensorlist import TensorList
+from ...core import OptimizerModule, _Targets
+
+
+def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
+    """Applies ⟂Grad - projects the gradients of an iterable of parameters to be orthogonal to the weights.
+
+    Args:
+        params (abc.Iterable[torch.Tensor]): parameters that hold the gradients to apply ⟂Grad to.
+        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
+
+    Reference:
+        https://arxiv.org/abs/2501.04697
+    """
+    if not isinstance(params, TensorList): params = TensorList(params)
+    params = params.with_grad()
+    grad = params.grad
+    grad -= ((params*grad).total_sum() / ((params*params).total_sum() + eps)) * params
+
+class OrthoGrad(OptimizerModule):
+    """⟂Grad - projects the update to be orthogonal to the weights.
+
+    Args:
+        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
+        renormalize (bool, optional): whether to renormalize the update back to its original norm (default: True).
+        sqrt_scale (bool, optional):
+            uses the square root of the scale to make it more impactful; an experimental setting that doesn't really work (default: False).
+        add (bool, optional):
+            experimental option that changes the subtraction to an addition.
+            It probably has no geometric meaning, but it drives the weights towards zero instead of away from it,
+            and it seems to work well with sqrt_scale = True. It speeds up convergence by a lot compared to the vanilla gradient,
+            but also overfits severely.
+        target (str, optional):
+            determines what this module updates.
+
+            "ascent" - it updates the ascent (default).
+
+            "grad" - it updates the gradient (and sets the `.grad` attributes to the updated gradient).
+
+            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
+
+    Reference:
+        https://arxiv.org/abs/2501.04697
+    """
+    def __init__(self, eps: float = 1e-30, renormalize=True, sqrt_scale=False, add=False, target: _Targets = 'ascent'):
+        super().__init__({}, target=target)
+        self.eps = eps
+        self.add = add
+        self.renormalize = renormalize
+        self.sqrt_scale = sqrt_scale
+
+    def _update(self, state, ascent):
+        params = self.get_params()
+
+        if self.renormalize: orig_norm = ascent.norm(2) + self.eps
+        else: orig_norm = 1
+
+        scale = (params*ascent).total_sum() / ((params*params).total_sum() + self.eps)
+        if self.sqrt_scale:
+            scale = scale.abs().sqrt() * scale.sign()
+
+        if self.add: ascent += params * scale
+        else: ascent -= params * scale
+
+        if self.renormalize:
+            ascent *= (orig_norm / ascent.norm(2))
+
+        return ascent
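
To make the projection explicit, here is a minimal plain-PyTorch sketch of the same ⟂Grad step for a single parameter tensor, independent of the TensorList machinery above; `orthograd_single`, `w`, and `g` are illustrative names, not part of the package.

    import torch

    def orthograd_single(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
        # project g onto the component orthogonal to w:
        # g_perp = g - (<w, g> / (<w, w> + eps)) * w
        scale = (w * g).sum() / ((w * w).sum() + eps)
        return g - scale * w

    w = torch.randn(10)
    g = torch.randn(10)
    g_perp = orthograd_single(w, g)
    print(torch.dot(w, g_perp))  # ~0: the projected gradient is orthogonal to the weights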
torchzero/modules/regularization/weight_decay.py
@@ -0,0 +1,92 @@
+from typing import Literal
+from collections.abc import Iterable
+
+import torch
+
+from ...tensorlist import TensorList
+from ...core import OptimizerModule, _Targets
+
+
+def l2_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
+    """Adds an L2 weight regularization term to the gradients in-place.
+
+    Args:
+        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
+        alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
+    """
+    p = TensorList(params).with_requires_grad()
+    p.ensure_grad_()
+    p.grad.add_(p, alpha = alpha)
+
+def l1_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
+    """Adds an L1 weight regularization term to the gradients in-place.
+
+    Args:
+        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
+        alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
+    """
+    p = TensorList(params).with_requires_grad()
+    p.ensure_grad_()
+    p.grad.add_(p.sign(), alpha = alpha)
+
+def weight_decay_penalty(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord: float = 2):
+    """Calculates the weight decay penalty term that can be added to the loss.
+
+    Args:
+        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
+        alpha (float): multiplier to the regularizer.
+        ord (float, optional): order of the norm. Defaults to 2.
+    """
+    return TensorList(params).norm(ord) * alpha
+
+def decay_weights_(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord: Literal[1, 2] = 2):
+    """Applies weight decay directly to parameters in-place.
+
+    Args:
+        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor to decay.
+        alpha (float): by how much to decay parameters (default: 1e-2)
+        ord (Literal[1, 2], optional):
+            order of the penalty; 1 and 2 are currently supported (L1 and L2 regularization) (default: 2)
+    """
+    params = TensorList(params)
+    if ord == 2: params.mul_(1 - alpha)
+    elif ord == 1: params.sub_(params.sign().mul_(alpha))
+    else: raise NotImplementedError(f'order {ord} is not supported')
+
+
+class WeightDecay(OptimizerModule):
+    """Adds a weight decay term (L1 or L2 regularization) to the ascent direction.
+
+    Put this at the end of the module chain to make it decoupled.
+
+    Args:
+        alpha (float, optional): multiplier to the regularizer (default: 1e-2)
+        ord (Literal[1, 2], optional):
+            order of the penalty; 1 and 2 are currently supported (L1 and L2 regularization).
+            Defaults to 2.
+        target (str, optional):
+            determines what this module updates.
+
+            "ascent" - it updates the ascent (default).
+
+            "grad" - it updates the gradient (and sets the `.grad` attributes to the updated gradient).
+
+            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
+    """
+    def __init__(self, alpha: float = 1e-2, ord: Literal[1, 2] = 2, target: _Targets = "ascent"):
+        defaults = dict(alpha = alpha)
+        super().__init__(defaults, target = target)
+        self.ord = ord
+
+    @torch.no_grad
+    def _update(self, state, ascent):
+        params = self.get_params()
+        alpha = self.get_group_key('alpha')
+
+        if any(i != 0 for i in alpha):
+
+            if self.ord == 1: ascent.add_(params.sign() * alpha)
+            elif self.ord == 2: ascent.add_(params * alpha)
+            else: raise NotImplementedError(f'weight decay of order {self.ord} is not implemented.')
+
+        return ascent
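
As a hedged illustration of the difference between coupled and decoupled weight decay using the functional helpers defined above (the model and alpha values are placeholders, and the import uses the module path shown in this diff):

    import torch
    from torch import nn
    from torchzero.modules.regularization.weight_decay import l2_regularize_, decay_weights_

    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    loss = nn.functional.mse_loss(model(torch.randn(8, 4)), torch.randn(8, 2))
    loss.backward()

    # coupled: add alpha * w to the gradients before the step
    l2_regularize_(model.parameters(), alpha=1e-2)
    opt.step()

    # decoupled alternative: leave the gradients alone and instead shrink
    # the weights after the step, i.e. w <- w * (1 - alpha)
    # decay_weights_(model.parameters(), alpha=1e-2)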
torchzero/modules/scheduling/__init__.py
@@ -0,0 +1,2 @@
+from .lr_schedulers import LRWarmup
+from .step_size import PolyakStepSize, RandomStepSize