torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,19 @@
-from typing import Literal
-from collections.abc import Callable
 import torch
 
-from ...core import Module, Target, Transform, Chainable, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
-def sophia_H(
-    tensors: TensorList,
-    h: TensorList | None,
-    exp_avg_: TensorList,
-    h_exp_avg_: TensorList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    update_freq: int,
-    precond_scale: float | NumberList,
-    clip: float | NumberList,
-    eps: float | NumberList,
-    step: int
-):
-    # momentum
-    exp_avg_.lerp_(tensors, 1-beta1)
-
-    # update preconditioner
-    if step % update_freq == 0:
-        assert h is not None
-        h_exp_avg_.lerp_(h, 1-beta2)
-
-    else:
-        assert h is None
-
-    denom = (h_exp_avg_ * precond_scale).clip_(min=eps)
-    return (exp_avg_ / denom).clip_(-clip, clip)
-
-
-class SophiaH(Module):
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+
+class SophiaH(Transform):
     """SophiaH optimizer from https://arxiv.org/abs/2305.14342
 
     This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is agressively clipped.
 
-    .. note::
-        In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.
+    Notes:
+        - In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply SophiaH preconditioning to another module's output.
 
-    .. note::
-        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
         beta1 (float, optional): first momentum. Defaults to 0.96.
@@ -60,46 +27,48 @@ class SophiaH(Module):
         eps (float, optional):
            clips hessian diagonal esimate to be no less than this value. Defaults to 1e-12.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
-
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
         seed (int | None, optional): seed for random vectors. Defaults to None.
         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
 
-    Examples:
-        Using SophiaH:
+    ### Examples:
 
-        .. code-block:: python
+    Using SophiaH:
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.SophiaH(),
-                tz.m.LR(0.1)
-            )
+    ```python
 
-    SophiaH preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
-    Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
-    SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SophiaH(),
+        tz.m.LR(0.1)
+    )
+    ```
 
-        .. code-block:: python
+    SophiaH preconditioner can be applied to any other module by passing it to the ``inner`` argument.
+    Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
+    SophiaH preconditioning to nesterov momentum (``tz.m.NAG``):
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
-                tz.m.LR(0.1)
-            )
+    ```python
 
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
+        tz.m.LR(0.1)
+    )
+    ```
     """
     def __init__(
         self,
@@ -109,77 +78,84 @@
         precond_scale: float = 1,
         clip: float = 1,
         eps: float = 1e-12,
-        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
-        fd_h: float = 1e-3,
+        hvp_method: HVPMethod = 'autograd',
+        distribution: Distributions = 'gaussian',
+        h: float = 1e-3,
         n_samples = 1,
+        zHz: bool = True,
+        debias: bool = False,
         seed: int | None = None,
-        inner: Chainable | None = None
+
+        exp_avg_tfm: Chainable | None = None,
+        D_exp_avg_tfm: Chainable | None = None,
     ):
-        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, precond_scale=precond_scale, clip=clip, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        defaults = locals().copy()
+        del defaults['self'], defaults['exp_avg_tfm'], defaults["D_exp_avg_tfm"]
         super().__init__(defaults)
 
-        if inner is not None:
-            self.set_child('inner', inner)
+        self.set_child('exp_avg', exp_avg_tfm)
+        self.set_child('D_exp_avg', D_exp_avg_tfm)
 
     @torch.no_grad
-    def step(self, var):
-        params = var.params
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        fd_h = settings['fd_h']
-        update_freq = settings['update_freq']
-        n_samples = settings['n_samples']
+    def update_states(self, objective, states, settings):
+        params = objective.params
 
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
+        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
 
-        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
-            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)
+        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg', cls=TensorList)
 
-        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+        step = self.increment_counter("step", start=0) # 0 on 1st update
 
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
+        # ---------------------------- hutchinson hessian ---------------------------- #
+        fs = settings[0]
+        update_freq = fs['update_freq']
 
-        closure = var.closure
-        assert closure is not None
-
-        h = None
         if step % update_freq == 0:
+            self.increment_counter("num_Ds", start=1)
+
+            D, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = fs['n_samples'],
+                distribution = fs['distribution'],
+                hvp_method = fs['hvp_method'],
+                h = fs['h'],
+                zHz = fs["zHz"],
+                generator = self.get_generator(params[0].device, fs["seed"]),
+            )
+
+            D_exp_avg.lerp_(D, weight=1-beta2)
+
+        # --------------------------------- momentum --------------------------------- #
+        tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
+        exp_avg.lerp_(tensors, 1-beta1)
+
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+
+        beta1, beta2, eps, precond_scale, clip = unpack_dicts(
+            settings, 'beta1', 'beta2', 'eps', 'precond_scale', 'clip', cls=NumberList)
+
+        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg')
+
+        # ---------------------------------- debias ---------------------------------- #
+        if settings[0]["debias"]:
+            bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
+            bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])
+
+            exp_avg = exp_avg / bias_correction1
+            D_exp_avg = D_exp_avg / bias_correction2
+
+        # -------------------------------- transforms -------------------------------- #
+        exp_avg = TensorList(self.inner_step_tensors(
+            "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))
+
+        D_exp_avg = TensorList(self.inner_step_tensors(
+            "D_exp_avg", tensors=D_exp_avg, clone=True, objective=objective, must_exist=False))
 
-            rgrad=None
-            for i in range(n_samples):
-                u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
-
-                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
-                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)
-                Hvp = tuple(Hvp)
-
-                if h is None: h = Hvp
-                else: torch._foreach_add_(h, Hvp)
-
-        assert h is not None
-        if n_samples > 1: torch._foreach_div_(h, n_samples)
-
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
-
-        var.update = sophia_H(
-            tensors=TensorList(update),
-            h=TensorList(h) if h is not None else None,
-            exp_avg_=exp_avg,
-            h_exp_avg_=h_exp_avg,
-            beta1=beta1,
-            beta2=beta2,
-            update_freq=update_freq,
-            precond_scale=precond_scale,
-            clip=clip,
-            eps=eps,
-            step=step,
-        )
-        return var
+        # ------------------------------ compute update ------------------------------ #
+        denom = D_exp_avg.lazy_mul(precond_scale).clip(min=eps)
+        objective.updates = (exp_avg / denom).clip_(-clip, clip)
+        return objective
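
Judging by the class names, the hunks above rewrite torchzero/modules/adaptive/sophia_h.py (entry 48 in the file list): the standalone sophia_H helper is folded into the class, SophiaH changes base class from Module to Transform with split update_states/apply_states phases, and the Hutchinson estimate now goes through objective.hutchinson_hessian with new distribution, zHz and debias options. Below is a minimal usage sketch against the new 0.4.0 signature, not a verified example: the tz.Modular / tz.m.* entry points and the backward-aware closure come from the docstring shown above, while the toy model, data, and the closure body follow the usual PyTorch closure pattern and are assumptions.

```python
# Sketch only: keyword names (hvp_method, distribution, n_samples, update_freq)
# are taken from the new __init__ in the diff; everything else is illustrative.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.SophiaH(
        hvp_method="autograd",    # "fd_forward"/"fd_central" are the new finite-difference names
        distribution="gaussian",  # new in 0.4.0: distribution of the Hutchinson probe vectors
        n_samples=1,
        update_freq=10,
    ),
    tz.m.LR(0.1),
)

def closure(backward=True):
    # SophiaH re-evaluates the loss/gradients for HVPs, so the closure must
    # accept a `backward` argument (per the docstring above).
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    opt.step(closure)
```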
@@ -5,7 +5,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Target, Transform
+from ...core import Module, TensorTransform
 from ...utils import Metrics, NumberList, TensorList
 from ...utils.metrics import _METRICS
 
@@ -150,7 +150,7 @@ def normalize_grads_(
     _clip_norm_(grads, min=None, max=None, norm_value=norm_value, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
 
 
-class ClipValue(Transform):
+class ClipValue(TensorTransform):
     """Clips update magnitude to be within ``(-value, value)`` range.
 
     Args:
@@ -180,17 +180,17 @@ class ClipValue(Transform):
    ```
 
     """
-    def __init__(self, value: float, target: Target = 'update'):
+    def __init__(self, value: float):
         defaults = dict(value=value)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)
 
-class ClipNorm(Transform):
-    """Clips update norm to be no larger than `value`.
+class ClipNorm(TensorTransform):
+    """Clips update norm to be no larger than ``value``.
 
     Args:
         max_norm (float): value to clip norm to.
@@ -236,13 +236,12 @@ class ClipNorm(Transform):
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
-        target: Target = "update",
     ):
         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         max_norm = NumberList(s['max_norm'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
@@ -257,7 +256,7 @@ class ClipNorm(Transform):
         )
         return tensors
 
-class Normalize(Transform):
+class Normalize(TensorTransform):
     """Normalizes the update.
 
     Args:
@@ -304,13 +303,12 @@ class Normalize(Transform):
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
-        target: Target = "update",
     ):
         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         norm_value = NumberList(s['norm_value'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
 
@@ -362,7 +360,7 @@ def _centralize_(
     return tensors_
 
 
-class Centralize(Transform):
+class Centralize(TensorTransform):
     """Centralizes the update.
 
     Args:
@@ -395,13 +393,12 @@ class Centralize(Transform):
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 2,
-        target: Target = "update",
     ):
         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])
 
         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
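
The hunks above appear to come from torchzero/modules/clipping/clipping.py (entry 49): ClipValue, ClipNorm, Normalize and Centralize switch from Transform to TensorTransform, drop the target keyword, and rename apply_tensors to multi_tensor_apply. A hedged migration sketch follows; the tz.m.* paths mirror the docstring examples earlier in this diff and are an assumption rather than something the diff states explicitly.

```python
# Sketch of the constructor-level change visible in the diff: the `target=`
# keyword is gone in 0.4.0; the remaining keywords are unchanged in these hunks.
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(0.5),                 # 0.3.x: tz.m.ClipValue(0.5, target="update")
    tz.m.ClipNorm(max_norm=1.0, ord=2),  # 0.3.x: tz.m.ClipNorm(1.0, ..., target="update")
    tz.m.LR(0.01),
)
```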
@@ -1,13 +1,14 @@
+from collections.abc import Iterable, Sequence
 from operator import itemgetter
 from typing import Literal
-from collections.abc import Iterable, Sequence
 
 import torch
 
-from ...core import Module, Target, Transform, apply_transform, Chainable
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, Metrics
+from ...core import Chainable, TensorTransform, step
+from ...utils import Metrics, NumberList, TensorList, unpack_dicts, unpack_states
+
 
-class ClipNormByEMA(Transform):
+class ClipNormByEMA(TensorTransform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
 
     Args:
@@ -36,7 +37,7 @@ class ClipNormByEMA(Transform):
         super().__init__(defaults, inner=inner)
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])
 
@@ -83,7 +84,7 @@ class ClipNormByEMA(Transform):
         self.global_state['denom'] = denom
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         denom = self.global_state.pop('denom')
         torch._foreach_div_(tensors, denom)
         return tensors
@@ -106,45 +107,50 @@ class NormalizeByEMA(ClipNormByEMA):
 
 # TODO Centralize by EMA?
 
-class ClipValueByEMA(Transform):
+class ClipValueByEMA(TensorTransform):
     """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.
 
     Args:
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ema_init (str, optional):
-            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
-        ema_tfm (Chainable | None, optional):
+            How to initialize exponential moving average on first step,
+            "update" to use the first update or "zeros". Defaults to 'zeros'.
+        exp_avg_tfm (Chainable | None, optional):
            optional modules applied to exponential moving average before clipping by it. Defaults to None.
     """
     def __init__(
         self,
         beta=0.99,
-        ema_init: Literal['zeros', 'update'] = 'zeros',
-        ema_tfm:Chainable | None=None,
+        init: Literal['zeros', 'update'] = 'zeros',
+
         inner: Chainable | None = None,
+        exp_avg_tfm:Chainable | None=None,
     ):
-        defaults = dict(beta=beta, ema_init=ema_init)
+        defaults = dict(beta=beta, init=init)
         super().__init__(defaults, inner=inner)
 
-        if ema_tfm is not None:
-            self.set_child('ema_tfm', ema_tfm)
+        self.set_child('exp_avg', exp_avg_tfm)
 
-    @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
-        ema_init = itemgetter('ema_init')(settings[0])
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["init"] == "zeros":
+            state["exp_avg"] = torch.zeros_like(tensor)
+        else:
+            state["exp_avg"] = tensor.abs()
 
-        beta = unpack_dicts(settings, 'beta', cls=NumberList)
+    @torch.no_grad
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
 
-        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
-        ema.lerp_(tensors.abs(), 1-beta)
+        exp_avg = unpack_states(states, tensors, 'exp_avg', must_exist=True, cls=TensorList)
+        exp_avg.lerp_(tensors.abs(), 1-beta)
 
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-        ema = unpack_states(states, tensors, 'ema', cls=TensorList)
+        exp_avg = unpack_states(states, tensors, 'exp_avg')
 
-        if 'ema_tfm' in self.children:
-            ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))
+        exp_avg = TensorList(
+            self.inner_step_tensors("exp_avg", exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
 
-        tensors.clip_(-ema, ema)
+        tensors.clip_(-exp_avg, exp_avg)
         return tensors
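
These hunks look like torchzero/modules/clipping/ema_clipping.py (entry 50): the hook names change (update_tensors to multi_tensor_update, apply_tensors to multi_tensor_apply), and ClipValueByEMA renames ema_init to init and ema_tfm to exp_avg_tfm, moving state initialization into single_tensor_initialize. A hedged sketch of the renamed constructor arguments; the tz.m.ClipValueByEMA path is assumed from the package layout rather than stated in the diff.

```python
# Sketch only: keyword names come from the new __init__ in the diff above.
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    # 0.3.x: tz.m.ClipValueByEMA(beta=0.99, ema_init="zeros", ema_tfm=None)
    tz.m.ClipValueByEMA(beta=0.99, init="zeros"),  # 0.4.0 names per the diff
    tz.m.LR(0.01),
)
```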
@@ -2,11 +2,11 @@ from operator import itemgetter
 
 import torch
 
-from ...core import TensorwiseTransform, Target, Transform
-from ...utils import TensorList, as_tensorlist
+from ...core import TensorTransform
+from ...utils import TensorList
 
 
-class ClipValueGrowth(TensorwiseTransform):
+class ClipValueGrowth(TensorTransform):
     """Clips update value magnitude growth.
 
     Args:
@@ -27,13 +27,12 @@ class ClipValueGrowth(TensorwiseTransform):
         mul: float | None = 1.5,
         min_value: float | None = 1e-4,
         max_decay: float | None = 2,
-        target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults)
 
 
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
         add: float | None
 
@@ -115,7 +114,7 @@ def norm_growth_clip_(
     return tensor_.div_(denom), new_prev_norm, denom
 
 
-class ClipNormGrowth(Transform):
+class ClipNormGrowth(TensorTransform):
     """Clips update norm growth.
 
     Args:
@@ -130,7 +129,7 @@ class ClipNormGrowth(Transform):
            Next norm is at most :code:`max(previous norm * mul, max_decay)`.
            Defaults to 2.
         ord (float, optional): norm order. Defaults to 2.
-        parameterwise (bool, optional):
+        tensorwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         target (Target, optional): what to set on var. Defaults to "update".
     """
@@ -141,19 +140,17 @@ class ClipNormGrowth(Transform):
         min_value: float | None = 1e-4,
         max_decay: float | None = 2,
         ord: float = 2,
-        parameterwise=True,
-        target: Target = "update",
+        tensorwise=True,
     ):
-        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
-        super().__init__(defaults, target=target)
+        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, tensorwise=tensorwise)
+        super().__init__(defaults)
 
 
-
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        parameterwise = settings[0]['parameterwise']
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensorwise = settings[0]['tensorwise']
         tensors = TensorList(tensors)
 
-        if parameterwise:
+        if tensorwise:
             ts = tensors
             stts = states
             stns = settings
@@ -180,7 +177,7 @@ class ClipNormGrowth(Transform):
                 ord = setting['ord'],
             )
 
-        if not parameterwise:
+        if not tensorwise:
             tensors.from_vec_(ts[0])
 
         return tensors
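
Finally, these hunks match torchzero/modules/clipping/growth_clipping.py (entry 51): ClipValueGrowth and ClipNormGrowth move to TensorTransform, lose the target keyword, and ClipNormGrowth's parameterwise flag is renamed to tensorwise (the docstring's mention of target looks stale in the new version). A hedged sketch of the renamed flag, with the same caveat that the tz.m.* path is assumed.

```python
# Sketch only: `tensorwise` replaces `parameterwise` per the diff; other
# keyword defaults are not shown in these hunks, so they are left untouched.
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    # 0.3.x: tz.m.ClipNormGrowth(parameterwise=True, target="update")
    tz.m.ClipNormGrowth(tensorwise=True),
    tz.m.LR(0.01),
)
```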