torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/clipping/clipping.py
@@ -5,7 +5,7 @@ import math
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, generic_eq
+from ...utils import NumberList, TensorList
 
 
 def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
@@ -24,7 +24,7 @@ def _clip_norm_(
     min: float | NumberList | None,
     max: float | NumberList | None,
     norm_value: float | NumberList | None,
-    ord: float,
+    ord: float | Literal['mean_abs'],
     dim: int | Sequence[int] | Literal["global"] | None,
     inverse_dims: bool,
     min_size: int,
@@ -54,9 +54,13 @@ def _clip_norm_(
         size = math.prod(tensor.size(d) for d in real_dim)
         if size < min_size: continue
 
-        norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+        if ord == 'mean_abs':
+            norm = tensor.abs().mean(dim=real_dim, keepdim=True)
+        else:
+            norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+
         if norm.numel() == 1 and norm == 0: continue
-        norm = torch.where(norm == 0, 1, norm)
+        norm = torch.where(norm <= 1e-12, 1, norm)
 
         # normalize = True, perform normalization
         norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
@@ -90,7 +94,7 @@ def _clip_norm_(
 def clip_grad_norm_(
     params: Iterable[torch.Tensor],
     max_norm: float | None,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 2,
@@ -118,7 +122,7 @@ def clip_grad_norm_(
 def normalize_grads_(
     params: Iterable[torch.Tensor],
     norm_value: float,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 1,
@@ -145,14 +149,44 @@ def normalize_grads_(
 
 
 class ClipValue(Transform):
-    """Clips update magnitude to be within `(-value, value)` range."""
+    """Clips update magnitude to be within `(-value, value)` range.
+
+    Args:
+        value (float): value to clip to.
+        target (str): refer to :ref:`target argument` in documentation.
+
+    Examples:
+
+        Gradient clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipValue(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipValue(1),
+                tz.m.LR(1e-2),
+            )
+
+    """
     def __init__(self, value: float, target: Target = 'update'):
         defaults = dict(value=value)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        value = self.get_settings('value', params=params)
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)
 
 class ClipNorm(Transform):
@@ -172,23 +206,47 @@ class ClipNorm(Transform):
             minimal numer of elements in a parameter or slice to clip norm. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipNorm(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipNorm(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         max_norm: float,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        max_norm = self.get_settings('max_norm', params=params, cls=NumberList)
-        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        max_norm = NumberList(s['max_norm'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
             tensors_ = TensorList(tensors),
             min = 0,
@@ -218,23 +276,47 @@ class Normalize(Transform):
            minimal size of a dimension to normalize along it. Defaults to 1.
         target (str, optional):
            what this affects.
+
+    Examples:
+
+        Gradient normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Normalize(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Normalize(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         norm_value: float = 1,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        norm_value = self.get_settings('norm_value', params=params, cls=NumberList)
-        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        norm_value = NumberList(s['norm_value'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
 
         _clip_norm_(
             tensors_ = TensorList(tensors),
@@ -299,6 +381,21 @@ class Centralize(Transform):
            if True, the `dims` argument is inverted, and all other dimensions are centralized.
         min_size (int, optional):
            minimal size of a dimension to normalize along it. Defaults to 1.
+
+    Examples:
+
+        Standard gradient centralization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Centralize(dim=0),
+                tz.m.LR(1e-2),
+            )
+
+    References:
+        - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
     """
     def __init__(
         self,
@@ -308,11 +405,11 @@ class Centralize(Transform):
         target: Target = "update",
     ):
         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])
 
         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
 
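The 'mean_abs' option added to _clip_norm_, clip_grad_norm_, normalize_grads_, ClipNorm and Normalize above measures each slice by its mean absolute value instead of a vector norm. A minimal sketch of how the new option could be used through the tz.Modular API shown in the added docstrings (assumes torchzero 0.3.11):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)

    # clip each parameter's update so that its mean absolute value is at most 1,
    # using the 'mean_abs' ord introduced in this release
    opt = tz.Modular(
        model.parameters(),
        tz.m.ClipNorm(1, ord='mean_abs'),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )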
torchzero/modules/clipping/ema_clipping.py
@@ -4,8 +4,8 @@ from collections.abc import Iterable, Sequence
 
 import torch
 
-from ...core import Module, Target, Transform, apply, Chainable
-from ...utils import NumberList, TensorList, generic_eq
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 class ClipNormByEMA(Transform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -14,9 +14,10 @@ class ClipNormByEMA(Transform):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow but at most this value per step. Defaults to 1.5.
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
     """
@@ -29,18 +30,20 @@ class ClipNormByEMA(Transform):
         tensorwise:bool=True,
         max_ema_growth: float | None = 1.5,
         ema_init: Literal['zeros', 'update'] = 'zeros',
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults, inner=inner)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(self.settings[params[0]])
-
-        beta, eps = self.get_settings('beta', 'eps', params=params, cls=NumberList)
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
+        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])
+
+        beta, eps = unpack_dicts(settings, 'beta', 'eps', cls=NumberList)
+
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
 
-        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
         ema.lerp_(tensors, 1-beta)
 
         if tensorwise:
@@ -48,7 +51,7 @@ class ClipNormByEMA(Transform):
 
             # clip ema norm growth
             if max_ema_growth is not None:
-                prev_ema_norm = self.get_state('prev_ema_norm', params=params, init=ema_norm, cls=TensorList)
+                prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
                 ema_denom = (ema_norm / allowed_norm).clip(min=1)
                 ema.div_(ema_denom)
@@ -77,7 +80,12 @@ class ClipNormByEMA(Transform):
         if self.NORMALIZE: denom.clip_(min=eps[0])
         else: denom.clip_(min=1)
 
-        tensors.div_(denom)
+        self.global_state['denom'] = denom
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        denom = self.global_state.pop('denom')
+        torch._foreach_div_(tensors, denom)
         return tensors
 
 class NormalizeByEMA(ClipNormByEMA):
@@ -87,9 +95,10 @@ class NormalizeByEMA(ClipNormByEMA):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow but at most this value per step. Defaults to 1.5.
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
     """
@@ -98,38 +107,44 @@ class NormalizeByEMA(ClipNormByEMA):
     # TODO Centralize by EMA?
 
 class ClipValueByEMA(Transform):
-    """Clips magnitude of update to be no larger than magnitude of an exponential moving average of past (unclipped) updates.
+    """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.
 
     Args:
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
-        ema_tfm (Chainable | None, optional): optional modules applied to exponential moving average before clipping by it. Defaults to None.
+        ema_tfm (Chainable | None, optional):
+            optional modules applied to exponential moving average before clipping by it. Defaults to None.
     """
     def __init__(
         self,
         beta=0.99,
         ema_init: Literal['zeros', 'update'] = 'zeros',
         ema_tfm:Chainable | None=None,
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ema_init=ema_init)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults, inner=inner)
 
         if ema_tfm is not None:
             self.set_child('ema_tfm', ema_tfm)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        ema_init = itemgetter('ema_init')(self.settings[params[0]])
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        ema_init = itemgetter('ema_init')(settings[0])
 
-        beta = self.get_settings('beta', params=params, cls=NumberList)
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
         tensors = TensorList(tensors)
 
-        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
         ema.lerp_(tensors.abs(), 1-beta)
 
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        ema = unpack_states(states, tensors, 'ema', cls=TensorList)
+
         if 'ema_tfm' in self.children:
-            ema = TensorList(apply(self.children['ema_tfm'], ema, params, vars.grad, vars))
+            ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))
 
         tensors.clip_(-ema, ema)
         return tensors
torchzero/modules/clipping/growth_clipping.py
@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
            Next update is at most :code:`max(previous update * mul, max_decay)`.
            Defaults to 2.
-        target (Target, optional): what to set on vars.. Defaults to "update".
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,
@@ -30,15 +30,13 @@ class ClipValueGrowth(TensorwiseTransform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)
 
 
-    def transform(self, tensor, param, grad, vars):
-        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(self.settings[param])
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
         add: float | None
 
-        state = self.state[param]
-
         if add is None and mul is None:
             return tensor
 
@@ -122,7 +120,8 @@ class ClipNormGrowth(Transform):
 
     Args:
         add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
-        mul (float | None, optional): multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
+        mul (float | None, optional):
+            multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
         min_value (float | None, optional):
             minimum value for multiplicative clipping to prevent collapse to 0.
             Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
@@ -133,7 +132,7 @@ class ClipNormGrowth(Transform):
         ord (float, optional): norm order. Defaults to 2.
         parameterwise (bool, optional):
             if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
-        target (Target, optional): what to set on vars. Defaults to "update".
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,
@@ -146,39 +145,39 @@ class ClipNormGrowth(Transform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)
 
 
 
-    def transform(self, tensors, params, grads, vars):
-        parameterwise = self.settings[params[0]]['parameterwise']
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        parameterwise = settings[0]['parameterwise']
         tensors = TensorList(tensors)
 
         if parameterwise:
             ts = tensors
-            stts = [self.state[p] for p in params]
-            stns = [self.settings[p] for p in params]
+            stts = states
+            stns = settings
 
         else:
             ts = [tensors.to_vec()]
             stts = [self.global_state]
-            stns = [self.settings[params[0]]]
+            stns = [settings[0]]
 
 
-        for t,state, settings in zip(ts, stts, stns):
+        for t, state, setting in zip(ts, stts, stns):
             if 'prev_norm' not in state:
-                state['prev_norm'] = torch.linalg.vector_norm(t, ord=settings['ord']) # pylint:disable=not-callable
+                state['prev_norm'] = torch.linalg.vector_norm(t, ord=setting['ord']) # pylint:disable=not-callable
                 state['prev_denom'] = 1
                 continue
 
             _, state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
                 tensor_ = t,
                 prev_norm = state['prev_norm'],
-                add = settings['add'],
-                mul = settings['mul'],
-                min_value = settings['min_value'],
-                max_decay = settings['max_decay'],
-                ord = settings['ord'],
+                add = setting['add'],
+                mul = setting['mul'],
+                min_value = setting['min_value'],
+                max_decay = setting['max_decay'],
+                ord = setting['ord'],
             )
 
         if not parameterwise:
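The growth-clipping behaviour described in the ClipNormGrowth docstring above (next norm at most `previous norm + add` and/or `max(prev_norm, min_value) * mul`) can be summarised by a short standalone sketch; this illustrates the rule only and is not the library's norm_growth_clip_ helper:

    import torch

    def clip_norm_growth_(t: torch.Tensor, prev_norm: float,
                          add: float | None = None, mul: float | None = 1.5,
                          min_value: float = 1e-4) -> float:
        norm = torch.linalg.vector_norm(t).item()
        allowed = norm
        if add is not None:
            allowed = min(allowed, prev_norm + add)                  # additive bound
        if mul is not None:
            allowed = min(allowed, max(prev_norm, min_value) * mul)  # multiplicative bound
        if 0 < allowed < norm:
            t.mul_(allowed / norm)  # rescale the update in place
            norm = allowed
        return norm                 # tracked as prev_norm for the next step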
torchzero/modules/experimental/__init__.py
@@ -1,15 +1,41 @@
+"""This submodule contains various untested experimental modules, some of them are to be moved out of experimental when properly tested, some are to remain here forever or to be deleted depending on the degree of their usefulness."""
 from .absoap import ABSOAP
 from .adadam import Adadam
+from .adam_lambertw import AdamLambertW
 from .adamY import AdamY
+from .adaptive_step_size import AdaptiveStepSize
 from .adasoap import AdaSOAP
+from .cosine import (
+    AdaptiveDifference,
+    AdaptiveDifferenceEMA,
+    CosineDebounce,
+    CosineMomentum,
+    CosineStepSize,
+    ScaledAdaptiveDifference,
+)
+from .cubic_adam import CubicAdam
 from .curveball import CurveBall
-from .soapy import SOAPY
+
+# from dct import DCTProjection
+from .eigendescent import EigenDescent
+from .etf import (
+    ExponentialTrajectoryFit,
+    ExponentialTrajectoryFitV2,
+    PointwiseExponential,
+)
+from .exp_adam import ExpAdam
+from .expanded_lbfgs import ExpandedLBFGS
+from .fft import FFTProjection
 from .gradmin import GradMin
+from .hnewton import HNewton
+from .modular_lbfgs import ModularLBFGS
+from .newton_solver import NewtonSolver
+from .newtonnewton import NewtonNewton
+from .parabolic_search import CubicParabolaSearch, ParabolaSearch
 from .reduce_outward_lr import ReduceOutwardLR
-from .spectral import SpectralPreconditioner
+from .structural_projections import BlockPartition, TensorizeProjection
 from .subspace_preconditioners import (
     HistorySubspacePreconditioning,
     RandomSubspacePreconditioning,
 )
-from .tropical_newton import TropicalNewton
-from .newton_solver import NewtonSolver
+from .tensor_adagrad import TensorAdagrad
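Per the __init__.py diff above, the new experimental modules are re-exported from torchzero.modules.experimental, so they can be imported directly by name (a quick sketch, assuming torchzero 0.3.11 is installed):

    # names taken from the import list above
    from torchzero.modules.experimental import (
        AdamLambertW,
        CubicAdam,
        EigenDescent,
        HNewton,
        TensorAdagrad,
    )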