torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/clipping/clipping.py
@@ -0,0 +1,320 @@
+ from operator import itemgetter
+ from typing import Literal
+ from collections.abc import Iterable, Sequence
+ import math
+ import torch
+
+ from ...core import Module, Target, Transform
+ from ...utils import NumberList, TensorList, generic_eq
+
+
+ def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
+     """Clips gradients of an iterable of parameters at the specified value.
+     Gradients are modified in-place.
+     Args:
+         params (Iterable[Tensor]): iterable of tensors with gradients to clip.
+         value (float or int): maximum allowed magnitude of the gradient.
+     """
+     grads = [p.grad for p in params if p.grad is not None]
+     torch._foreach_clamp_min_(grads, -value)
+     torch._foreach_clamp_max_(grads, value)
+
+ def _clip_norm_(
+     tensors_: TensorList,
+     min: float | NumberList | None,
+     max: float | NumberList | None,
+     norm_value: float | NumberList | None,
+     ord: float,
+     dim: int | Sequence[int] | Literal["global"] | None,
+     inverse_dims: bool,
+     min_size: int,
+ ) -> TensorList:
+     """generic function that can clip norm or normalize"""
+     if norm_value is not None:
+         if min is not None or max is not None:
+             raise ValueError(f'if norm_value is given then min and max must be None, got {min = }; {max = }')
+
+         # if dim is None: return tensors_.mul_(norm_value / tensors_.norm(ord=ord))
+         if dim == 'global': return tensors_.mul_(norm_value / tensors_.global_vector_norm(ord=ord))
+
+     # if dim is None: return tensors_.clip_norm_(min,max,tensorwise=True,ord=ord)
+     if dim == 'global': return tensors_.clip_norm_(min,max,tensorwise=False,ord=ord)
+
+     muls = []
+     tensors_to_mul = []
+     if isinstance(dim, int): dim = (dim, )
+
+     for i, tensor in enumerate(tensors_):
+         # remove dimensions that overflow tensor.ndim or are too small
+         if tensor.ndim == 0: tensor = tensor.unsqueeze(0)
+         if dim is None: dim = list(range(tensor.ndim))
+         real_dim = [d for d in dim if d < tensor.ndim]
+         if inverse_dims: real_dim = [d for d in range(tensor.ndim) if d not in real_dim]
+         if len(real_dim) == 0: continue
+         size = math.prod(tensor.size(d) for d in real_dim)
+         if size < min_size: continue
+
+         norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+         if norm.numel() == 1 and norm == 0: continue
+         norm = torch.where(norm == 0, 1, norm)
+
+         # normalize = True, perform normalization
+         norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
+         if norm_v is not None:
+             mul = norm_v / norm
+
+         # else clip to min and max norms
+         else:
+             minv = min[i] if isinstance(min, (list,tuple)) else min
+             maxv = max[i] if isinstance(max, (list,tuple)) else max
+
+             mul = 1
+             if minv is not None:
+                 mul_to_min = (minv / norm).clamp(min=1)
+                 mul *= mul_to_min
+
+             if maxv is not None:
+                 mul_to_max = (maxv / norm).clamp(max=1)
+                 mul *= mul_to_max
+
+         muls.append(mul)
+         tensors_to_mul.append(tensor)
+
+     if len(muls) > 0:
+
+
+         torch._foreach_mul_(tensors_to_mul, muls)
+     return tensors_
+
+
+ def clip_grad_norm_(
+     params: Iterable[torch.Tensor],
+     max_norm: float | None,
+     ord: float = 2,
+     dim: int | Sequence[int] | Literal["global"] | None = None,
+     inverse_dims: bool = False,
+     min_size: int = 2,
+     min_norm: float | None = None,
+ ):
+     """Clips gradients of an iterable of parameters to the specified norm value.
+     Gradients are modified in-place.
+
+     Args:
+         params (Iterable[torch.Tensor]): parameters with gradients to clip.
+         max_norm (float | None): value to clip the norm to.
+         ord (float, optional): norm order. Defaults to 2.
+         dim (int | Sequence[int] | str | None, optional):
+             calculates norm along those dimensions.
+             If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+             Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+             Defaults to None.
+         min_size (int, optional):
+             minimal size of a dimension to normalize along it. Defaults to 2.
+     """
+     grads = TensorList(p.grad for p in params if p.grad is not None)
+     _clip_norm_(grads, min=min_norm, max=max_norm, norm_value=None, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
+
+
+ def normalize_grads_(
+     params: Iterable[torch.Tensor],
+     norm_value: float,
+     ord: float = 2,
+     dim: int | Sequence[int] | Literal["global"] | None = None,
+     inverse_dims: bool = False,
+     min_size: int = 1,
+ ):
+     """Normalizes gradients of an iterable of parameters to the specified norm value.
+     Gradients are modified in-place.
+
+     Args:
+         params (Iterable[torch.Tensor]): parameters with gradients to normalize.
+         norm_value (float): value to set the norm to.
+         ord (float, optional): norm order. Defaults to 2.
+         dim (int | Sequence[int] | str | None, optional):
+             calculates norm along those dimensions.
+             If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+             Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+             Defaults to None.
+         inverse_dims (bool, optional):
+             if True, the `dims` argument is inverted, and all other dimensions are normalized.
+         min_size (int, optional):
+             minimal size of a dimension to normalize along it. Defaults to 1.
+     """
+     grads = TensorList(p.grad for p in params if p.grad is not None)
+     _clip_norm_(grads, min=None, max=None, norm_value=norm_value, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
+
+
+ class ClipValue(Transform):
+     """Clips update magnitude to be within `(-value, value)` range."""
+     def __init__(self, value: float, target: Target = 'update'):
+         defaults = dict(value=value)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         value = self.get_settings('value', params=params)
+         return TensorList(tensors).clip_([-v for v in value], value)
+
+ class ClipNorm(Transform):
+     """Clips update norm to be no larger than `max_norm`.
+
+     Args:
+         max_norm (float): value to clip the norm to.
+         ord (float, optional): norm order. Defaults to 2.
+         dim (int | Sequence[int] | str | None, optional):
+             calculates norm along those dimensions.
+             If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+             Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+             Defaults to None.
+         inverse_dims (bool, optional):
+             if True, the `dims` argument is inverted, and all other dimensions are normalized.
+         min_size (int, optional):
+             minimal number of elements in a parameter or slice to clip norm. Defaults to 1.
+         target (str, optional):
+             what this affects.
+     """
+     def __init__(
+         self,
+         max_norm: float,
+         ord: float = 2,
+         dim: int | Sequence[int] | Literal["global"] | None = None,
+         inverse_dims: bool = False,
+         min_size: int = 1,
+         target: Target = "update",
+     ):
+         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         max_norm = self.get_settings('max_norm', params=params, cls=NumberList)
+         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+         _clip_norm_(
+             tensors_ = TensorList(tensors),
+             min = 0,
+             max = max_norm,
+             norm_value = None,
+             ord = ord,
+             dim = dim,
+             inverse_dims=inverse_dims,
+             min_size = min_size,
+         )
+         return tensors
+
+ class Normalize(Transform):
+     """Normalizes the update.
+
+     Args:
+         norm_value (float, optional): desired norm value. Defaults to 1.
+         ord (float, optional): norm order. Defaults to 2.
+         dim (int | Sequence[int] | str | None, optional):
+             calculates norm along those dimensions.
+             If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+             Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+             Defaults to None.
+         inverse_dims (bool, optional):
+             if True, the `dims` argument is inverted, and all other dimensions are normalized.
+         min_size (int, optional):
+             minimal size of a dimension to normalize along it. Defaults to 1.
+         target (str, optional):
+             what this affects.
+     """
+     def __init__(
+         self,
+         norm_value: float = 1,
+         ord: float = 2,
+         dim: int | Sequence[int] | Literal["global"] | None = None,
+         inverse_dims: bool = False,
+         min_size: int = 1,
+         target: Target = "update",
+     ):
+         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         norm_value = self.get_settings('norm_value', params=params, cls=NumberList)
+         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+
+         _clip_norm_(
+             tensors_ = TensorList(tensors),
+             min = None,
+             max = None,
+             norm_value = norm_value,
+             ord = ord,
+             dim = dim,
+             inverse_dims=inverse_dims,
+             min_size = min_size,
+         )
+
+         return tensors
+
+
+ def _centralize_(
+     tensors_: TensorList,
+     dim: int | Sequence[int] | Literal["global"] | None,
+     min_size: int,
+     inverse_dims: bool,
+ ) -> TensorList:
+     """generic function that centralizes tensors along the given dimensions"""
+     if dim == 'global': return tensors_.sub_(tensors_.global_mean().item())
+
+     subs = []
+     tensors_to_sub = []
+     if isinstance(dim, int): dim = (dim, )
+
+     for tensor in tensors_:
+         # remove dimensions that overflow tensor.ndim or are too small
+         if dim is None: dim = list(range(tensor.ndim))
+         real_dim = [d for d in dim if d < tensor.ndim]
+         if inverse_dims: real_dim = [d for d in range(tensor.ndim) if d not in real_dim]
+         if len(real_dim) == 0: continue
+         size = math.prod(tensor.size(d) for d in real_dim)
+         if size < min_size: continue
+
+         mean: torch.Tensor = torch.mean(tensor, dim=real_dim, keepdim=True)
+         if mean.numel() == 1 and mean == 0: continue
+
+         subs.append(mean)
+         tensors_to_sub.append(tensor)
+
+     if len(subs) > 0:
+         torch._foreach_sub_(tensors_to_sub, subs)
+
+     return tensors_
+
+
+ class Centralize(Transform):
+     """Centralizes the update.
+
+     Args:
+         dim (int | Sequence[int] | str | None, optional):
+             calculates mean along those dimensions.
+             If list/tuple, tensors are centralized along all dimensions in `dim` that they have.
+             Can be set to "global" to centralize by global mean of all gradients concatenated to a vector.
+             Defaults to None.
+         inverse_dims (bool, optional):
+             if True, the `dims` argument is inverted, and all other dimensions are centralized.
+         min_size (int, optional):
+             minimal size of a dimension to centralize along it. Defaults to 2.
+         target (str, optional):
+             what this affects.
+     """
+     def __init__(
+         self,
+         dim: int | Sequence[int] | Literal["global"] | None = None,
+         inverse_dims: bool = False,
+         min_size: int = 2,
+         target: Target = "update",
+     ):
+         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+
+         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
+
+         return tensors
+
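For orientation, here is a minimal sketch of how the functional helpers introduced above might be called directly. It assumes they can be imported from the file path shown in this hunk; the package may also re-export them under a shorter public name.

```python
# Hedged sketch: exercises clip_grad_value_ / clip_grad_norm_ / normalize_grads_
# from the new clipping module. The import path mirrors the file location in this
# diff; the exact public re-export location is an assumption.
import torch
from torchzero.modules.clipping.clipping import (
    clip_grad_value_, clip_grad_norm_, normalize_grads_,
)

model = torch.nn.Linear(10, 1)
model(torch.randn(4, 10)).pow(2).mean().backward()
params = list(model.parameters())

clip_grad_value_(params, value=0.1)                  # clamp each .grad entry to [-0.1, 0.1]
clip_grad_norm_(params, max_norm=1.0, dim="global")  # clip the global L2 norm of all grads
normalize_grads_(params, norm_value=1.0)             # rescale each grad tensor to unit norm
```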
torchzero/modules/clipping/ema_clipping.py
@@ -0,0 +1,135 @@
+ from operator import itemgetter
+ from typing import Literal
+ from collections.abc import Iterable, Sequence
+
+ import torch
+
+ from ...core import Module, Target, Transform, apply, Chainable
+ from ...utils import NumberList, TensorList, generic_eq
+
+ class ClipNormByEMA(Transform):
+     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
+
+     Args:
+         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
+         ord (float, optional): order of the norm. Defaults to 2.
+         eps (float, optional): epsilon for division. Defaults to 1e-6.
+         tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+         max_ema_growth (float | None, optional):
+             if specified, the norm of the exponential moving average can grow by at most this factor per step. Defaults to 1.5.
+         ema_init (str, optional):
+             how to initialize the exponential moving average on the first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
+     """
+     NORMALIZE = False
+     def __init__(
+         self,
+         beta=0.99,
+         ord: float = 2,
+         eps=1e-6,
+         tensorwise:bool=True,
+         max_ema_growth: float | None = 1.5,
+         ema_init: Literal['zeros', 'update'] = 'zeros',
+     ):
+         defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(self.settings[params[0]])
+
+         beta, eps = self.get_settings('beta', 'eps', params=params, cls=NumberList)
+         tensors = TensorList(tensors)
+
+         ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
+         ema.lerp_(tensors, 1-beta)
+
+         if tensorwise:
+             ema_norm = ema.norm(ord)
+
+             # clip ema norm growth
+             if max_ema_growth is not None:
+                 prev_ema_norm = self.get_state('prev_ema_norm', params=params, init=ema_norm, cls=TensorList)
+                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
+                 ema_denom = (ema_norm / allowed_norm).clip(min=1)
+                 ema.div_(ema_denom)
+                 ema_norm.div_(ema_denom)
+                 prev_ema_norm.set_(ema_norm)
+
+             tensors_norm = tensors.norm(ord)
+             denom = tensors_norm / ema_norm.clip(min=eps)
+             if self.NORMALIZE: denom.clip_(min=eps)
+             else: denom.clip_(min=1)
+
+         else:
+             ema_norm = ema.global_vector_norm(ord)
+
+             # clip ema norm growth
+             if max_ema_growth is not None:
+                 prev_ema_norm = self.global_state.setdefault('prev_ema_norm', ema_norm)
+                 allowed_norm = prev_ema_norm * max_ema_growth
+                 if ema_norm > allowed_norm:
+                     ema.div_(ema_norm / allowed_norm)
+                     ema_norm = allowed_norm
+                 prev_ema_norm.set_(ema_norm)
+
+             tensors_norm = tensors.global_vector_norm(ord)
+             denom = tensors_norm / ema_norm.clip(min=eps[0])
+             if self.NORMALIZE: denom.clip_(min=eps[0])
+             else: denom.clip_(min=1)
+
+         tensors.div_(denom)
+         return tensors
+
+ class NormalizeByEMA(ClipNormByEMA):
+     """Sets norm of the update to be the same as the norm of an exponential moving average of past updates.
+
+     Args:
+         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
+         ord (float, optional): order of the norm. Defaults to 2.
+         eps (float, optional): epsilon for division. Defaults to 1e-6.
+         tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+         max_ema_growth (float | None, optional):
+             if specified, the norm of the exponential moving average can grow by at most this factor per step. Defaults to 1.5.
+         ema_init (str, optional):
+             how to initialize the exponential moving average on the first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
+     """
+     NORMALIZE = True
+
+ # TODO Centralize by EMA?
+
+ class ClipValueByEMA(Transform):
+     """Clips magnitude of the update to be no larger than the magnitude of an exponential moving average of past (unclipped) updates.
+
+     Args:
+         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
+         ema_init (str, optional):
+             how to initialize the exponential moving average on the first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
+         ema_tfm (Chainable | None, optional): optional modules applied to the exponential moving average before clipping by it. Defaults to None.
+     """
+     def __init__(
+         self,
+         beta=0.99,
+         ema_init: Literal['zeros', 'update'] = 'zeros',
+         ema_tfm:Chainable | None=None,
+     ):
+         defaults = dict(beta=beta, ema_init=ema_init)
+         super().__init__(defaults, uses_grad=False)
+
+         if ema_tfm is not None:
+             self.set_child('ema_tfm', ema_tfm)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         ema_init = itemgetter('ema_init')(self.settings[params[0]])
+
+         beta = self.get_settings('beta', params=params, cls=NumberList)
+         tensors = TensorList(tensors)
+
+         ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
+         ema.lerp_(tensors.abs(), 1-beta)
+
+         if 'ema_tfm' in self.children:
+             ema = TensorList(apply(self.children['ema_tfm'], ema, params, vars.grad, vars))
+
+         tensors.clip_(-ema, ema)
+         return tensors
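The tensorwise branch of ClipNormByEMA reduces to a small amount of plain tensor arithmetic. The sketch below restates that core logic outside the Transform/TensorList machinery, without the max_ema_growth cap, purely as an illustration of what the module computes.

```python
# Illustration only: clip each update's norm to the norm of an EMA of past updates,
# mirroring ClipNormByEMA's tensorwise branch (growth capping omitted).
import torch

def clip_norm_by_ema_sketch(updates, ema, beta=0.99, ord=2, eps=1e-6):
    for u, m in zip(updates, ema):
        m.lerp_(u, 1 - beta)                        # ema <- beta*ema + (1-beta)*update
        ema_norm = torch.linalg.vector_norm(m, ord)
        ratio = torch.linalg.vector_norm(u, ord) / ema_norm.clamp(min=eps)
        u.div_(ratio.clamp(min=1))                  # shrink only when update norm > ema norm
    return updates

updates = [torch.randn(3, 3), torch.randn(5)]
ema = [torch.zeros_like(t) for t in updates]        # corresponds to ema_init='zeros'
clip_norm_by_ema_sketch(updates, ema)
```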
torchzero/modules/clipping/growth_clipping.py
@@ -0,0 +1,187 @@
+ from operator import itemgetter
+
+ import torch
+
+ from ...core import TensorwiseTransform, Target, Transform
+ from ...utils import TensorList, as_tensorlist
+
+
+ class ClipValueGrowth(TensorwiseTransform):
+     """Clips update value magnitude growth.
+
+     Args:
+         add (float | None, optional): additive clipping, next update is at most `previous update + add`. Defaults to None.
+         mul (float | None, optional): multiplicative clipping, next update is at most `previous update * mul`. Defaults to 1.5.
+         min_value (float | None, optional):
+             minimum value for multiplicative clipping to prevent collapse to 0.
+             Next update is at most :code:`max(prev_update, min_value) * mul`. Defaults to 1e-4.
+         max_decay (float | None, optional):
+             bounds the tracked multiplicative clipping decay to prevent collapse to 0.
+             Next update is at most :code:`max(previous update * mul, max_decay)`.
+             Defaults to 2.
+         target (Target, optional): what to set on vars. Defaults to "update".
+     """
+     def __init__(
+         self,
+         add: float | None = None,
+         mul: float | None = 1.5,
+         min_value: float | None = 1e-4,
+         max_decay: float | None = 2,
+         target: Target = "update",
+     ):
+         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+
+     def transform(self, tensor, param, grad, vars):
+         add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(self.settings[param])
+         add: float | None
+
+         state = self.state[param]
+
+         if add is None and mul is None:
+             return tensor
+
+         if 'prev' not in state:
+             state['prev'] = tensor.clone()
+             return tensor
+
+         prev: torch.Tensor = state['prev']
+
+         # additive bound
+         if add is not None:
+             growth = (tensor.abs() - prev.abs()).clip(min=0)
+             tensor.sub_(torch.where(growth > add, (growth-add).copysign_(tensor), 0))
+
+         # multiplicative bound
+         growth = None
+         if mul is not None:
+             prev_magn = prev.abs()
+             if min_value is not None: prev_magn.clip_(min=min_value)
+             growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)
+
+             denom = torch.where(growth > mul, growth/mul, 1)
+
+             tensor.div_(denom)
+
+         # limit max growth decay
+         if max_decay is not None:
+             if growth is None:
+                 prev_magn = prev.abs()
+                 if min_value is not None: prev_magn.clip_(min=min_value)
+                 growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)
+
+             new_prev = torch.where(growth < (1/max_decay), prev/max_decay, tensor)
+         else:
+             new_prev = tensor.clone()
+
+         state['prev'] = new_prev
+         return tensor
+
+
+ def norm_growth_clip_(
+     tensor_: torch.Tensor,
+     prev_norm: torch.Tensor,
+     add: float | None,
+     mul: float | None,
+     min_value: float | None,
+     max_decay: float | None,
+     ord: float,
+ ):
+     if add is None and mul is None: return tensor_
+     norm = torch.linalg.vector_norm(tensor_, ord=ord) # pylint:disable=not-callable
+
+     denom = 1
+     # additive bound
+     if add is not None:
+         allowed_norm = prev_norm + add
+         if norm > allowed_norm: denom = norm / allowed_norm
+
+     # multiplicative bound
+     if mul is not None:
+         allowed_norm = prev_norm * mul
+         if norm > allowed_norm: denom = max(denom, norm / allowed_norm)
+
+     # minimal norm
+     if min_value is not None:
+         denom = max(denom, min_value)
+
+     # limit max growth decay
+     new_prev_norm = norm/denom
+     if max_decay is not None:
+         decay = norm / prev_norm
+         if decay < (1/max_decay):
+             new_prev_norm = prev_norm / max_decay
+
+     if min_value is not None: new_prev_norm = max(new_prev_norm, min_value) # pyright:ignore[reportArgumentType]
+     return tensor_.div_(denom), new_prev_norm, denom
+
+
+ class ClipNormGrowth(Transform):
+     """Clips update norm growth.
+
+     Args:
+         add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
+         mul (float | None, optional): multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
+         min_value (float | None, optional):
+             minimum value for multiplicative clipping to prevent collapse to 0.
+             Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
+         max_decay (float | None, optional):
+             bounds the tracked multiplicative clipping decay to prevent collapse to 0.
+             Next norm is at most :code:`max(previous norm * mul, max_decay)`.
+             Defaults to 2.
+         ord (float, optional): norm order. Defaults to 2.
+         parameterwise (bool, optional):
+             if True, norms are calculated parameter-wise, otherwise treats all parameters as a single vector. Defaults to True.
+         target (Target, optional): what to set on vars. Defaults to "update".
+     """
+     def __init__(
+         self,
+         add: float | None = None,
+         mul: float | None = 1.5,
+         min_value: float | None = 1e-4,
+         max_decay: float | None = 2,
+         ord: float = 2,
+         parameterwise=True,
+         target: Target = "update",
+     ):
+         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+
+
+     def transform(self, tensors, params, grads, vars):
+         parameterwise = self.settings[params[0]]['parameterwise']
+         tensors = TensorList(tensors)
+
+         if parameterwise:
+             ts = tensors
+             stts = [self.state[p] for p in params]
+             stns = [self.settings[p] for p in params]
+
+         else:
+             ts = [tensors.to_vec()]
+             stts = [self.global_state]
+             stns = [self.settings[params[0]]]
+
+
+         for t,state, settings in zip(ts, stts, stns):
+             if 'prev_norm' not in state:
+                 state['prev_norm'] = torch.linalg.vector_norm(t, ord=settings['ord']) # pylint:disable=not-callable
+                 state['prev_denom'] = 1
+                 continue
+
+             _, state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
+                 tensor_ = t,
+                 prev_norm = state['prev_norm'],
+                 add = settings['add'],
+                 mul = settings['mul'],
+                 min_value = settings['min_value'],
+                 max_decay = settings['max_decay'],
+                 ord = settings['ord'],
+             )
+
+         if not parameterwise:
+             tensors.from_vec_(ts[0])
+
+         return tensors
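The module-level norm_growth_clip_ helper above can also be driven by hand. A hedged sketch follows; the import path mirrors the new file's location and is an assumption about how it is exposed.

```python
# Hedged sketch: limit per-step norm growth of a sequence of updates using the
# norm_growth_clip_ helper added above; prev_norm is tracked manually here.
import torch
from torchzero.modules.clipping.growth_clipping import norm_growth_clip_

update = torch.randn(10)
prev_norm = torch.linalg.vector_norm(update)   # first step only records the norm

for step in range(5):
    update = torch.randn(10) * (step + 2)      # pretend updates keep growing
    update, prev_norm, denom = norm_growth_clip_(
        tensor_=update,
        prev_norm=prev_norm,
        add=None,          # no additive bound
        mul=1.5,           # norm may grow by at most 1.5x per step
        min_value=1e-4,
        max_decay=2,
        ord=2,
    )
```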
torchzero/modules/experimental/__init__.py
@@ -1,19 +1,14 @@
- """Optimizers that I haven't tested and various (mostly stupid) ideas go there.
- If something works well I will move it outside of experimental folder.
- Otherwise all optimizers in this category should be considered unlikely to good for most tasks."""
- from .experimental import GradMin, HVPDiagNewton, MinibatchRprop, ReduceOutwardLR
- from .quad_interp import QuadraticInterpolation2Point
- from .subspace import (
-     Proj2Masks,
-     ProjAscent,
-     ProjAscentRay,
-     Projection,
-     ProjGrad,
-     ProjGradAscentDifference,
-     ProjGradRay,
-     ProjLastAscentDifference,
-     ProjLastGradDifference,
-     ProjNormalize,
-     ProjRandom,
-     Subspace,
+ from .absoap import ABSOAP
+ from .adadam import Adadam
+ from .adamY import AdamY
+ from .adasoap import AdaSOAP
+ from .curveball import CurveBall
+ from .dsoap import DSOAP
+ from .gradmin import GradMin
+ from .reduce_outward_lr import ReduceOutwardLR
+ from .spectral import SpectralPreconditioner
+ from .subspace_preconditioners import (
+     HistorySubspacePreconditioning,
+     RandomSubspacePreconditioning,
  )
+ from .tropical_newton import TropicalNewton
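After this change the experimental namespace exposes the new research modules directly. A hedged import sketch, based only on the __init__ shown above (these modules sit in the experimental package, so their availability may shift between releases):

```python
from torchzero.modules.experimental import (
    GradMin,
    ReduceOutwardLR,
    SpectralPreconditioner,
    TropicalNewton,
)
```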