torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0

torchzero/modules/adaptive/sophia_h.py (new file)
@@ -0,0 +1,185 @@
+ from typing import Literal
+ from collections.abc import Callable
+ import torch
+
+ from ...core import Module, Target, Transform, Chainable, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist
+ def sophia_H(
+     tensors: TensorList,
+     h: TensorList | None,
+     exp_avg_: TensorList,
+     h_exp_avg_: TensorList,
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     update_freq: int,
+     precond_scale: float | NumberList,
+     clip: float | NumberList,
+     eps: float | NumberList,
+     step: int
+ ):
+     # momentum
+     exp_avg_.lerp_(tensors, 1-beta1)
+
+     # update preconditioner
+     if step % update_freq == 0:
+         assert h is not None
+         h_exp_avg_.lerp_(h, 1-beta2)
+
+     else:
+         assert h is None
+
+     denom = (h_exp_avg_ * precond_scale).clip_(min=eps)
+     return (exp_avg_ / denom).clip_(-clip, clip)
+
+
+ class SophiaH(Module):
+     """SophiaH optimizer from https://arxiv.org/abs/2305.14342
+
+     This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is aggressively clipped.
+
+     .. note::
+         In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.
+
+     .. note::
+         If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+     .. note::
+         This module requires a closure to be passed to the optimizer step,
+         as it needs to re-evaluate the loss and gradients for calculating HVPs.
+         The closure must accept a ``backward`` argument (refer to documentation).
+
+     Args:
+         beta1 (float, optional): first momentum. Defaults to 0.96.
+         beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
+         update_freq (int, optional):
+             frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
+         precond_scale (float, optional):
+             scale of the preconditioner. Defaults to 1.
+         clip (float, optional):
+             clips update to (-clip, clip). Defaults to 1.
+         eps (float, optional):
+             clips hessian diagonal estimate to be no less than this value. Defaults to 1e-12.
+         hvp_method (str, optional):
+             Determines how Hessian-vector products are evaluated.
+
+             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+               This requires creating a graph for the gradient.
+             - ``"forward"``: Use a forward finite difference formula to
+               approximate the HVP. This requires one extra gradient evaluation.
+             - ``"central"``: Use a central finite difference formula for a
+               more accurate HVP approximation. This requires two extra
+               gradient evaluations.
+             Defaults to "autograd".
+         fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+         n_samples (int, optional):
+             number of hessian-vector products with random vectors to evaluate each time when updating
+             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+         seed (int | None, optional): seed for random vectors. Defaults to None.
+         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+     Examples:
+         Using SophiaH:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.SophiaH(),
+                 tz.m.LR(0.1)
+             )
+
+         SophiaH preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+         Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
+         SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
+                 tz.m.LR(0.1)
+             )
+
+     """
+     def __init__(
+         self,
+         beta1: float = 0.96,
+         beta2: float = 0.99,
+         update_freq: int = 10,
+         precond_scale: float = 1,
+         clip: float = 1,
+         eps: float = 1e-12,
+         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+         fd_h: float = 1e-3,
+         n_samples = 1,
+         seed: int | None = None,
+         inner: Chainable | None = None
+     ):
+         defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, precond_scale=precond_scale, clip=clip, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         fd_h = settings['fd_h']
+         update_freq = settings['update_freq']
+         n_samples = settings['n_samples']
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+         beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
+             'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)
+
+         exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         closure = var.closure
+         assert closure is not None
+
+         h = None
+         if step % update_freq == 0:
+
+             rgrad=None
+             for i in range(n_samples):
+                 u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                     h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+                 Hvp = tuple(Hvp)
+
+                 if h is None: h = Hvp
+                 else: torch._foreach_add_(h, Hvp)
+
+             assert h is not None
+             if n_samples > 1: torch._foreach_div_(h, n_samples)
+
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+         var.update = sophia_H(
+             tensors=TensorList(update),
+             h=TensorList(h) if h is not None else None,
+             exp_avg_=exp_avg,
+             h_exp_avg_=h_exp_avg,
+             beta1=beta1,
+             beta2=beta2,
+             update_freq=update_freq,
+             precond_scale=precond_scale,
+             clip=clip,
+             eps=eps,
+             step=step,
+         )
+         return var
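
A note on the closure requirement documented above: SophiaH re-evaluates the loss and gradients for its Hessian-vector products, so the optimizer must be stepped with a closure that accepts a `backward` argument. A minimal sketch of such a closure, assuming the conventional `import torchzero as tz` alias and that `tz.Modular` exposes the usual `torch.optim`-style `zero_grad`/`step` interface; `model`, `loss_fn`, `inputs` and `targets` are placeholders:

```python
import torchzero as tz  # assumed import alias

opt = tz.Modular(
    model.parameters(),
    tz.m.SophiaH(),
    tz.m.LR(0.1),
)

def closure(backward=True):
    # the module re-evaluates this closure when computing Hessian-vector products
    loss = loss_fn(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```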

torchzero/modules/clipping/clipping.py
@@ -1,11 +1,13 @@
+ import math
+ from collections.abc import Iterable, Sequence
  from operator import itemgetter
  from typing import Literal
- from collections.abc import Iterable, Sequence
- import math
+
  import torch

  from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList, generic_eq
+ from ...utils import Metrics, NumberList, TensorList
+ from ...utils.metrics import _METRICS


  def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
@@ -24,7 +26,7 @@ def _clip_norm_(
  min: float | NumberList | None,
  max: float | NumberList | None,
  norm_value: float | NumberList | None,
- ord: float,
+ ord: Metrics,
  dim: int | Sequence[int] | Literal["global"] | None,
  inverse_dims: bool,
  min_size: int,
@@ -35,7 +37,7 @@ def _clip_norm_(
  raise ValueError(f'if norm_value is given then min and max must be None got {min = }; {max = }')

  # if dim is None: return tensors_.mul_(norm_value / tensors_.norm(ord=ord))
- if dim == 'global': return tensors_.mul_(norm_value / tensors_.global_vector_norm(ord=ord))
+ if dim == 'global': return tensors_.mul_(norm_value / tensors_.global_metric(ord))

  # if dim is None: return tensors_.clip_norm_(min,max,tensorwise=True,ord=ord)
  if dim == 'global': return tensors_.clip_norm_(min,max,tensorwise=False,ord=ord)
@@ -54,9 +56,13 @@ def _clip_norm_(
  size = math.prod(tensor.size(d) for d in real_dim)
  if size < min_size: continue

- norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+ if isinstance(ord, str):
+     norm = _METRICS[ord].evaluate_tensor(tensor, dim=real_dim, keepdim=True)
+ else:
+     norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+
  if norm.numel() == 1 and norm == 0: continue
- norm = torch.where(norm == 0, 1, norm)
+ norm = torch.where(norm <= 1e-12, 1, norm)

  # normalize = True, perform normalization
  norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
@@ -90,7 +96,7 @@ def _clip_norm_(
  def clip_grad_norm_(
  params: Iterable[torch.Tensor],
  max_norm: float | None,
- ord: float = 2,
+ ord: Metrics = 2,
  dim: int | Sequence[int] | Literal["global"] | None = None,
  inverse_dims: bool = False,
  min_size: int = 2,
@@ -101,7 +107,7 @@ def clip_grad_norm_(

  Args:
  params (Iterable[torch.Tensor]): parameters with gradients to clip.
- value (float): value to clip norm to.
+ max_norm (float): value to clip norm to.
  ord (float, optional): norm order. Defaults to 2.
  dim (int | Sequence[int] | str | None, optional):
  calculates norm along those dimensions.
@@ -118,7 +124,7 @@ def clip_grad_norm_(
  def normalize_grads_(
  params: Iterable[torch.Tensor],
  norm_value: float,
- ord: float = 2,
+ ord: Metrics = 2,
  dim: int | Sequence[int] | Literal["global"] | None = None,
  inverse_dims: bool = False,
  min_size: int = 1,
@@ -145,13 +151,41 @@ def normalize_grads_(


  class ClipValue(Transform):
- """Clips update magnitude to be within `(-value, value)` range."""
+ """Clips update magnitude to be within ``(-value, value)`` range.
+
+ Args:
+ value (float): value to clip to.
+ target (str): refer to ``target argument`` in documentation.
+
+ Examples:
+
+ Gradient clipping:
+ ```python
+ opt = tz.Modular(
+     model.parameters(),
+     tz.m.ClipValue(1),
+     tz.m.Adam(),
+     tz.m.LR(1e-2),
+ )
+ ```
+
+ Update clipping:
+ ```python
+ opt = tz.Modular(
+     model.parameters(),
+     tz.m.Adam(),
+     tz.m.ClipValue(1),
+     tz.m.LR(1e-2),
+ )
+ ```
+
+ """
  def __init__(self, value: float, target: Target = 'update'):
  defaults = dict(value=value)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  value = [s['value'] for s in settings]
  return TensorList(tensors).clip_([-v for v in value], value)

@@ -159,7 +193,7 @@ class ClipNorm(Transform):
  """Clips update norm to be no larger than `value`.

  Args:
- value (float): value to clip norm to.
+ max_norm (float): value to clip norm to.
  ord (float, optional): norm order. Defaults to 2.
  dim (int | Sequence[int] | str | None, optional):
  calculates norm along those dimensions.
@@ -172,21 +206,43 @@ class ClipNorm(Transform):
  minimal number of elements in a parameter or slice to clip norm. Defaults to 1.
  target (str, optional):
  what this affects.
+
+ Examples:
+
+ Gradient norm clipping:
+ ```python
+ opt = tz.Modular(
+     model.parameters(),
+     tz.m.ClipNorm(1),
+     tz.m.Adam(),
+     tz.m.LR(1e-2),
+ )
+ ```
+
+ Update norm clipping:
+ ```python
+ opt = tz.Modular(
+     model.parameters(),
+     tz.m.Adam(),
+     tz.m.ClipNorm(1),
+     tz.m.LR(1e-2),
+ )
+ ```
  """
  def __init__(
  self,
  max_norm: float,
- ord: float = 2,
+ ord: Metrics = 2,
  dim: int | Sequence[int] | Literal["global"] | None = None,
  inverse_dims: bool = False,
  min_size: int = 1,
  target: Target = "update",
  ):
  defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  max_norm = NumberList(s['max_norm'] for s in settings)
  ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
  _clip_norm_(
@@ -205,7 +261,7 @@ class Normalize(Transform):
  """Normalizes the update.

  Args:
- value (float): desired norm value.
+ norm_value (float): desired norm value.
  ord (float, optional): norm order. Defaults to 2.
  dim (int | Sequence[int] | str | None, optional):
  calculates norm along those dimensions.
@@ -218,21 +274,43 @@ class Normalize(Transform):
  minimal size of a dimension to normalize along it. Defaults to 1.
  target (str, optional):
  what this affects.
+
+ Examples:
+ Gradient normalization:
+ ```python
+ opt = tz.Modular(
+     model.parameters(),
+     tz.m.Normalize(1),
+     tz.m.Adam(),
+     tz.m.LR(1e-2),
+ )
+ ```
+
+ Update normalization:
+
+ ```python
+ opt = tz.Modular(
+     model.parameters(),
+     tz.m.Adam(),
+     tz.m.Normalize(1),
+     tz.m.LR(1e-2),
+ )
+ ```
  """
  def __init__(
  self,
  norm_value: float = 1,
- ord: float = 2,
+ ord: Metrics = 2,
  dim: int | Sequence[int] | Literal["global"] | None = None,
  inverse_dims: bool = False,
  min_size: int = 1,
  target: Target = "update",
  ):
  defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  norm_value = NumberList(s['norm_value'] for s in settings)
  ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

@@ -288,8 +366,6 @@ class Centralize(Transform):
  """Centralizes the update.

  Args:
- value (float): desired norm value.
- ord (float, optional): norm order. Defaults to 2.
  dim (int | Sequence[int] | str | None, optional):
  calculates norm along those dimensions.
  If list/tuple, tensors are centralized along all dimensions in `dim` that they have.
@@ -299,6 +375,20 @@ class Centralize(Transform):
  if True, the `dims` argument is inverted, and all other dimensions are centralized.
  min_size (int, optional):
  minimal size of a dimension to normalize along it. Defaults to 1.
+
+ Examples:
+
+ Standard gradient centralization:
+ ```python
+ opt = tz.Modular(
+     model.parameters(),
+     tz.m.Centralize(dim=0),
+     tz.m.LR(1e-2),
+ )
+ ```
+
+ References:
+ - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
  """
  def __init__(
  self,
@@ -308,10 +398,10 @@ class Centralize(Transform):
  target: Target = "update",
  ):
  defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

  _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
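
Across these clipping hunks the transforms switch from overriding `apply` to the new `apply_tensors` hook and stop passing `uses_grad` to `super().__init__`. Below is a minimal sketch of a custom transform written against that 0.3.13-style hook, mirroring the `ClipValue` pattern above. The `ScaleByConstant` class is hypothetical (not part of the package), the absolute import paths are assumed from the relative imports shown in the diff, and it assumes `TensorList.mul_` accepts a per-parameter list of scalars the way `clip_` does:

```python
import torch
from torchzero.core import Transform, Target  # assumed absolute import paths
from torchzero.utils import TensorList

class ScaleByConstant(Transform):
    """Hypothetical example: scales the update by a constant factor."""
    def __init__(self, scale: float, target: Target = 'update'):
        defaults = dict(scale=scale)
        super().__init__(defaults, target=target)  # no `uses_grad` in the 0.3.13 API

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        # `settings` holds one dict per parameter, as in ClipValue above
        scale = [s['scale'] for s in settings]
        return TensorList(tensors).mul_(scale)
```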

torchzero/modules/clipping/ema_clipping.py
@@ -5,7 +5,7 @@ from collections.abc import Iterable, Sequence
  import torch

  from ...core import Module, Target, Transform, apply_transform, Chainable
- from ...utils import NumberList, TensorList, generic_eq, unpack_dicts, unpack_states
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, Metrics

  class ClipNormByEMA(Transform):
  """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -14,9 +14,10 @@ class ClipNormByEMA(Transform):
  beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
  ord (float, optional): order of the norm. Defaults to 2.
  eps (float, optional): epsilon for division. Defaults to 1e-6.
- tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+ tensorwise (bool, optional):
+ if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
  max_ema_growth (float | None, optional):
- if specified, exponential moving average norm can grow but at most this value per step. Defaults to 1.5.
+ if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
  ema_init (str, optional):
  How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
  """
@@ -24,17 +25,18 @@ class ClipNormByEMA(Transform):
  def __init__(
  self,
  beta=0.99,
- ord: float = 2,
+ ord: Metrics = 2,
  eps=1e-6,
  tensorwise:bool=True,
  max_ema_growth: float | None = 1.5,
  ema_init: Literal['zeros', 'update'] = 'zeros',
+ inner: Chainable | None = None,
  ):
  defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
- super().__init__(defaults, uses_grad=False)
+ super().__init__(defaults, inner=inner)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
  tensors = TensorList(tensors)
  ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])
@@ -45,7 +47,7 @@ class ClipNormByEMA(Transform):
  ema.lerp_(tensors, 1-beta)

  if tensorwise:
- ema_norm = ema.norm(ord)
+ ema_norm = ema.metric(ord)

  # clip ema norm growth
  if max_ema_growth is not None:
@@ -62,7 +64,7 @@ class ClipNormByEMA(Transform):
  else: denom.clip_(min=1)

  else:
- ema_norm = ema.global_vector_norm(ord)
+ ema_norm = ema.global_metric(ord)

  # clip ema norm growth
  if max_ema_growth is not None:
@@ -73,12 +75,17 @@ class ClipNormByEMA(Transform):
  ema_norm = allowed_norm
  prev_ema_norm.set_(ema_norm)

- tensors_norm = tensors.global_vector_norm(ord)
+ tensors_norm = tensors.global_metric(ord)
  denom = tensors_norm / ema_norm.clip(min=eps[0])
  if self.NORMALIZE: denom.clip_(min=eps[0])
  else: denom.clip_(min=1)

- tensors.div_(denom)
+ self.global_state['denom'] = denom
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+     denom = self.global_state.pop('denom')
+     torch._foreach_div_(tensors, denom)
  return tensors

  class NormalizeByEMA(ClipNormByEMA):
@@ -88,9 +95,10 @@ class NormalizeByEMA(ClipNormByEMA):
  beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
  ord (float, optional): order of the norm. Defaults to 2.
  eps (float, optional): epsilon for division. Defaults to 1e-6.
- tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+ tensorwise (bool, optional):
+ if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
  max_ema_growth (float | None, optional):
- if specified, exponential moving average norm can grow but at most this value per step. Defaults to 1.5.
+ if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
  ema_init (str, optional):
  How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
  """
@@ -99,28 +107,30 @@ class NormalizeByEMA(ClipNormByEMA):

  # TODO Centralize by EMA?

  class ClipValueByEMA(Transform):
- """Clips magnitude of update to be no larger than magnitude of an exponential moving average of past (unclipped) updates.
+ """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

  Args:
  beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
  ema_init (str, optional):
  How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
- ema_tfm (Chainable | None, optional): optional modules applied to exponential moving average before clipping by it. Defaults to None.
+ ema_tfm (Chainable | None, optional):
+ optional modules applied to exponential moving average before clipping by it. Defaults to None.
  """
  def __init__(
  self,
  beta=0.99,
  ema_init: Literal['zeros', 'update'] = 'zeros',
  ema_tfm:Chainable | None=None,
+ inner: Chainable | None = None,
  ):
  defaults = dict(beta=beta, ema_init=ema_init)
- super().__init__(defaults, uses_grad=False)
+ super().__init__(defaults, inner=inner)

  if ema_tfm is not None:
  self.set_child('ema_tfm', ema_tfm)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
  ema_init = itemgetter('ema_init')(settings[0])

  beta = unpack_dicts(settings, 'beta', cls=NumberList)
@@ -129,8 +139,12 @@ class ClipValueByEMA(Transform):
  ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
  ema.lerp_(tensors.abs(), 1-beta)

+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+     tensors = TensorList(tensors)
+     ema = unpack_states(states, tensors, 'ema', cls=TensorList)
+
  if 'ema_tfm' in self.children:
- ema = TensorList(apply_transform(self.children['ema_tfm'], ema, params, grads, loss))
+ ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))

  tensors.clip_(-ema, ema)
  return tensors
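
These hunks split the old `apply` hook into `update_tensors` (maintain the EMA statistics) and `apply_tensors` (apply the clip), and add an `inner` argument to both EMA clippers. A usage sketch, assuming `ClipNormByEMA` is exposed under `tz.m` like the other clipping modules and the conventional `import torchzero as tz` alias:

```python
import torchzero as tz  # assumed import alias

opt = tz.Modular(
    model.parameters(),
    # clip each step's norm against an EMA of past updates (documented defaults shown)
    tz.m.ClipNormByEMA(beta=0.99, max_ema_growth=1.5),
    tz.m.LR(1e-2),
)
```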

torchzero/modules/clipping/growth_clipping.py
@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
  bounds the tracked multiplicative clipping decay to prevent collapse to 0.
  Next update is at most :code:`max(previous update * mul, max_decay)`.
  Defaults to 2.
- target (Target, optional): what to set on var.. Defaults to "update".
+ target (Target, optional): what to set on var. Defaults to "update".
  """
  def __init__(
  self,
@@ -30,11 +30,11 @@ class ClipValueGrowth(TensorwiseTransform):
  target: Target = "update",
  ):
  defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)


- def apply_tensor(self, tensor, param, grad, loss, state, settings):
- add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(settings)
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
  add: float | None
  if add is None and mul is None:
@@ -120,7 +120,8 @@ class ClipNormGrowth(Transform):

  Args:
  add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
- mul (float | None, optional): multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
+ mul (float | None, optional):
+ multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
  min_value (float | None, optional):
  minimum value for multiplicative clipping to prevent collapse to 0.
  Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
@@ -144,11 +145,11 @@ class ClipNormGrowth(Transform):
  target: Target = "update",
  ):
  defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)



- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  parameterwise = settings[0]['parameterwise']
  tensors = TensorList(tensors)

torchzero/modules/conjugate_gradient/__init__.py (new file)
@@ -0,0 +1,11 @@
+ from .cg import (
+     DYHS,
+     ConjugateDescent,
+     DaiYuan,
+     FletcherReeves,
+     HagerZhang,
+     HestenesStiefel,
+     LiuStorey,
+     PolakRibiere,
+     ProjectedGradientMethod,
+ )
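
The new `conjugate_gradient` subpackage collects the nonlinear conjugate gradient update rules that previously lived in `quasi_newton/cg.py`. Nonlinear CG is normally paired with a line search; a sketch, assuming these classes are exposed under `tz.m` and that the line search in `line_search/strong_wolfe.py` is exported as `tz.m.StrongWolfe` (both names are inferred, not confirmed by this diff):

```python
import torchzero as tz  # assumed import alias

opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),  # Polak-Ribiere direction rule, listed in the import block above
    tz.m.StrongWolfe(),   # assumed name for the strong Wolfe line search
)
```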