torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/natural_gradient.py
@@ -1,12 +1,12 @@
  import torch
- from ...core import Module, Chainable, apply_transform
+ from ...core import Transform

  from ...utils.derivatives import jacobian_wrt, flatten_jacobian
- from ...utils import vec_to_tensors, TensorList
- from ...utils.linalg import linear_operator
+ from ...utils import vec_to_tensors
+ from ...linalg import linear_operator

  from .lmadagrad import lm_adagrad_apply, lm_adagrad_update

- class NaturalGradient(Module):
+ class NaturalGradient(Transform):
  """Natural gradient approximated via empirical fisher information matrix.

  To use this, either pass vector of per-sample losses to the step method, or make sure
@@ -27,9 +27,9 @@ class NaturalGradient(Module):
  with a vector that isn't strictly per-sample gradients, but rather for example different losses.
  gn_grad (bool, optional):
  if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
- and is equivalent to squaring the values. This way you can solve least-squares
- objectives with a NGD-like algorithm. If False, uses sum of per-sample gradients.
- This has an effect when ``sqrt=True``, and affects the ``grad`` attribute.
+ and is equivalent to squaring the values. That makes the kernel trick solver incorrect, but for
+ some reason it still works. If False, uses sum of per-sample gradients.
+ This has an effect when ``sqrt=False``, and affects the ``grad`` attribute.
  Defaults to False.
  batched (bool, optional): whether to use vmapping. Defaults to True.

@@ -97,20 +97,21 @@ class NaturalGradient(Module):
  super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))

  @torch.no_grad
- def update(self, var):
- params = var.params
- batched = self.defaults['batched']
- gn_grad = self.defaults['gn_grad']
+ def update_states(self, objective, states, settings):
+ params = objective.params
+ fs = settings[0]
+ batched = fs['batched']
+ gn_grad = fs['gn_grad']

- closure = var.closure
+ closure = objective.closure
  assert closure is not None

  with torch.enable_grad():
- f = var.get_loss(backward=False) # n_out
+ f = objective.get_loss(backward=False) # n_out
  assert isinstance(f, torch.Tensor)
  G_list = jacobian_wrt([f.ravel()], params, batched=batched)

- var.loss = f.sum()
+ objective.loss = f.sum()
  G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

  if gn_grad:
@@ -119,13 +120,13 @@ class NaturalGradient(Module):
  else:
  g = self.global_state["g"] = G.sum(0)

- var.grad = vec_to_tensors(g, params)
+ objective.grads = vec_to_tensors(g, params)

  # set closure to calculate scalar value for line searches etc
- if var.closure is not None:
+ if objective.closure is not None:
  def ngd_closure(backward=True):
  if backward:
- var.zero_grad()
+ objective.zero_grad()
  with torch.enable_grad():
  loss = closure(False)
  if gn_grad: loss = loss.pow(2)
@@ -137,13 +138,14 @@ class NaturalGradient(Module):
  if gn_grad: loss = loss.pow(2)
  return loss.sum()

- var.closure = ngd_closure
+ objective.closure = ngd_closure

  @torch.no_grad
- def apply(self, var):
- params = var.params
- reg = self.defaults['reg']
- sqrt = self.defaults['sqrt']
+ def apply_states(self, objective, states, settings):
+ params = objective.params
+ fs = settings[0]
+ reg = fs['reg']
+ sqrt = fs['sqrt']

  G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)

@@ -151,12 +153,15 @@ class NaturalGradient(Module):
  # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
  # but it computes it through eigendecompotision
  U, L = lm_adagrad_update(G.H, reg, 0)
- if U is None or L is None: return var
+ if U is None or L is None: return objective

  v = lm_adagrad_apply(self.global_state["g"], U, L)
- var.update = vec_to_tensors(v, params)
- return var
+ objective.updates = vec_to_tensors(v, params)
+ return objective

+ # we need (G^T G)v = g
+ # where g = G^T
+ # so we need to solve (G^T G)v = G^T
  GGT = G @ G.H # (n_samples, n_samples)

  if reg != 0:
@@ -165,11 +170,11 @@ class NaturalGradient(Module):
  z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
  v = G.H @ z

- var.update = vec_to_tensors(v, params)
- return var
+ objective.updates = vec_to_tensors(v, params)
+ return objective


- def get_H(self, var):
+ def get_H(self, objective=...):
  if "G" not in self.global_state: return linear_operator.ScaledIdentity()
  G = self.global_state['G']
  return linear_operator.AtA(G)
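For orientation, here is a rough sketch of driving the reworked NaturalGradient with per-sample losses, as its docstring describes. `tz.Modular` and `tz.m.LR` appear in docstrings elsewhere in this diff; the `tz.m.NaturalGradient` alias and the `step(closure)` call are assumptions.

```python
import torch
import torchzero as tz  # assumed top-level import name

model = torch.nn.Linear(10, 1)
X, y = torch.randn(32, 10), torch.randn(32, 1)

# tz.Modular and tz.m.LR are taken from docstrings in this diff;
# tz.m.NaturalGradient is an assumed alias for the class above
opt = tz.Modular(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(1e-2),
)

def closure(backward=True):
    # return a vector of per-sample losses; NaturalGradient builds the
    # Jacobian of these itself, so no backward pass is needed here
    return (model(X) - y).pow(2).squeeze(-1)

opt.step(closure)  # assumed closure-style step, as in torch.optim
```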
torchzero/modules/adaptive/orthograd.py
@@ -1,13 +1,9 @@
- from operator import itemgetter
- import math
- import warnings
- from collections.abc import Iterable, Sequence
- from typing import Literal
+ from collections.abc import Iterable

  import torch

- from ...core import Target, Transform
- from ...utils import as_tensorlist
+ from ...core import TensorTransform
+ from ...utils import TensorList

  def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
  """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
@@ -19,29 +15,29 @@ def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
  reference
  https://arxiv.org/abs/2501.04697
  """
- params = as_tensorlist(params).with_grad()
+ params = TensorList(params).with_grad()
  grad = params.grad
  grad -= (params.dot(grad)/(params.dot(params) + eps)) * params


- class OrthoGrad(Transform):
+ class OrthoGrad(TensorTransform):
  """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

  Args:
  eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
  renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
- target (Target, optional): what to set on var. Defaults to 'update'.
  """
- def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
+ def __init__(self, eps: float = 1e-8, renormalize=True):
  defaults = dict(eps=eps, renormalize=renormalize)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ @torch.no_grad
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  eps = settings[0]['eps']
  renormalize = settings[0]['renormalize']

- params = as_tensorlist(params)
- target = as_tensorlist(tensors)
+ params = TensorList(params)
+ target = TensorList(tensors)

  scale = params.dot(target)/(params.dot(params) + eps)
  if renormalize:
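The functional `orthograd_` shown above mutates `.grad` in place; a minimal sketch of calling it directly after a backward pass (the import path follows this file's location in the wheel):

```python
import torch
from torchzero.modules.adaptive.orthograd import orthograd_

model = torch.nn.Linear(4, 4)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# project the gradients to be orthogonal to the weights (⟂Grad,
# https://arxiv.org/abs/2501.04697); eps guards the division
orthograd_(model.parameters(), eps=1e-30)
```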
torchzero/modules/adaptive/rmsprop.py
@@ -1,45 +1,11 @@
- from operator import itemgetter
  from typing import Literal

  import torch

- from ...core import Module, Target, Transform, Chainable, Var, apply_transform
+ from ...core import TensorTransform, Chainable
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_
-
-
-
- def rmsprop_(
- tensors_: TensorList,
- exp_avg_sq_: TensorList,
- smoothing: float | NumberList,
- eps: float | NumberList,
- debiased: bool,
- step: int,
- exp_avg_: TensorList | None = None,
- max_exp_avg_sq_: TensorList | None = None,
- pow: float = 2,
-
- # inner args
- inner: Module | None = None,
- params: list[torch.Tensor] | None = None,
- grads: list[torch.Tensor] | None = None,
- ):
- """returns `tensors_`"""
- if exp_avg_ is not None:
- sqrt_exp_avg_sq = sqrt_centered_ema_sq_(tensors=tensors_, exp_avg_=exp_avg_,
- exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
- beta=smoothing,debiased=debiased,step=step,pow=pow)
- else:
- sqrt_exp_avg_sq = sqrt_ema_sq_(tensors=tensors_,exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
- beta=smoothing,debiased=debiased,step=step,pow=pow)
-
- if inner is not None:
- assert params is not None
- tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
- return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
-
- class RMSprop(Transform):
+
+ class RMSprop(TensorTransform):
  """Divides graient by EMA of gradient squares.

  This implementation is identical to :code:`torch.optim.RMSprop`.
@@ -48,7 +14,7 @@ class RMSprop(Transform):
  smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
  eps (float, optional): epsilon for division. Defaults to 1e-8.
  centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
- debiased (bool, optional): applies Adam debiasing. Defaults to False.
+ debias (bool, optional): applies Adam debiasing. Defaults to False.
  amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
  pow (float, optional): power used in second momentum power and root. Defaults to 2.
  init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
@@ -60,44 +26,86 @@ class RMSprop(Transform):
  smoothing: float = 0.99,
  eps: float = 1e-8,
  centered: bool = False,
- debiased: bool = False,
+ debias: bool = False,
  amsgrad: bool = False,
- pow: float = 2,
  init: Literal["zeros", "update"] = "zeros",
+
  inner: Chainable | None = None,
+ exp_avg_sq_tfm: Chainable | None = None,
  ):
- defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
- super().__init__(defaults=defaults, uses_grad=False)
-
- if inner is not None:
- self.set_child('inner', inner)
-
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
- step = self.global_state['step'] = self.global_state.get('step', 0) + 1
- smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
- centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
-
- exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
- exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList) if centered else None
- max_exp_avg_sq = unpack_states(states, tensors, 'max_exp_avg_sq', cls=TensorList) if amsgrad else None
-
- if init == 'update' and step == 1:
- exp_avg_sq.set_([t**2 for t in tensors])
- if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])
-
- return rmsprop_(
- TensorList(tensors),
- exp_avg_sq_=exp_avg_sq,
- smoothing=smoothing,
- eps=eps,
- debiased=debiased,
- step=step,
- exp_avg_=exp_avg,
- max_exp_avg_sq_=max_exp_avg_sq,
- pow=pow,
-
- # inner args
- inner=self.children.get("inner", None),
- params=params,
- grads=grads,
- )
+ defaults = locals().copy()
+ del defaults['self'], defaults["inner"], defaults["exp_avg_sq_tfm"]
+ super().__init__(defaults, inner=inner)
+
+ self.set_child('exp_avg_sq', exp_avg_sq_tfm)
+
+ @torch.no_grad
+ def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+ if setting["init"] == "zeros":
+ state["exp_avg_sq"] = torch.zeros_like(tensor)
+ if setting["centered"]: state["exp_avg"] = torch.zeros_like(tensor)
+ if setting["amsgrad"]: state["amsgrad"] = torch.zeros_like(tensor)
+
+ else:
+ state["exp_avg_sq"] = tensor ** 2
+ if setting["centered"]: state["exp_avg"] = tensor.clone()
+ if setting["amsgrad"]: state["amsgrad"] = tensor ** 2
+
+ @torch.no_grad
+ def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+ self.increment_counter("step", start = 0)
+ fs = settings[0]
+
+ exp_avg_sq = unpack_states(states, tensors, "exp_avg_sq", cls=TensorList)
+
+ # update exponential average
+ smoothing = NumberList(s["smoothing"] for s in settings)
+ exp_avg_sq.mul_(smoothing).addcmul_(tensors, tensors, value=1-smoothing)
+
+ # update mean estimate if centered
+ if fs["centered"]:
+ exp_avg = unpack_states(states, tensors, "exp_avg", cls=TensorList)
+ exp_avg.lerp_(tensors, 1-smoothing)
+
+ # amsgrad
+ if fs["amsgrad"]:
+ exp_avg_sq_max = unpack_states(states, tensors, "exp_avg_sq_max", cls=TensorList)
+ exp_avg_sq_max.maximum_(exp_avg_sq)
+
+ @torch.no_grad
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+ tensors = TensorList(tensors)
+ step = self.global_state["step"] # 0 on 1st step
+ eps = NumberList(s["eps"] for s in settings)
+ fs = settings[0]
+
+ if fs["amsgrad"]: key = "max_exp_avg_sq"
+ else: key = "exp_avg_sq"
+ exp_avg_sq = TensorList(s[key] for s in states)
+
+ # load mean estimate if centered
+ exp_avg = None
+ if fs['centered']:
+ exp_avg = TensorList(s["exp_avg"] for s in states)
+
+ # debias exp_avg_sq and exp_avg
+ if fs["debias"]:
+ smoothing = NumberList(s["smoothing"] for s in settings)
+ bias_correction = 1 - (smoothing ** (step + 1))
+ exp_avg_sq = exp_avg_sq / bias_correction
+
+ if fs['centered']:
+ assert exp_avg is not None
+ exp_avg = exp_avg / bias_correction
+
+ # apply transform to potentially debiased exp_avg_sq
+ exp_avg_sq = TensorList(self.inner_step_tensors(
+ "exp_avg_sq", exp_avg_sq, params=params, grads=grads, loss=loss, clone=True, must_exist=False
+ ))
+
+ # center
+ if fs["centered"]:
+ assert exp_avg is not None
+ exp_avg_sq = exp_avg_sq.addcmul(exp_avg, exp_avg, value=-1)
+
+ return tensors.div_(exp_avg_sq.sqrt().add_(eps))
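From the call site, the visible changes in this rewrite are the keyword rename (`debiased` to `debias`), the removal of `pow`, and the new `exp_avg_sq_tfm` child transform. A hedged construction example under 0.4.0, assuming the class is exported as `tz.m.RMSprop` in line with the `tz.m` convention used in the docstrings above:

```python
import torch
import torchzero as tz  # assumed top-level import name

model = torch.nn.Linear(10, 2)

# `debiased=` from 0.3.15 is now `debias=`; `pow=` is gone, and
# `exp_avg_sq_tfm=` can chain a transform onto the second-moment EMA
opt = tz.Modular(
    model.parameters(),
    tz.m.RMSprop(smoothing=0.99, eps=1e-8, centered=True, debias=True),  # tz.m.RMSprop is assumed
    tz.m.LR(1e-3),
)
```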
torchzero/modules/adaptive/rprop.py
@@ -1,8 +1,8 @@

  import torch

- from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+ from ...core import TensorTransform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


  def _bool_ones_like(x):
@@ -126,7 +126,7 @@ def rprop_(



- class Rprop(Transform):
+ class Rprop(TensorTransform):
  """
  Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
  or `nminus` if it did. Then the update is applied with the sign of the current gradient.
@@ -165,7 +165,7 @@ class Rprop(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

@@ -178,7 +178,7 @@ class Rprop(Transform):
  )

  tensors = rprop_(
- tensors_ = as_tensorlist(tensors),
+ tensors_ = TensorList(tensors),
  prev_ = prev,
  allowed_ = allowed,
  magnitudes_ = magnitudes,
@@ -194,7 +194,7 @@ class Rprop(Transform):
  return tensors


- class ScaleLRBySignChange(Transform):
+ class ScaleLRBySignChange(TensorTransform):
  """
  learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
  or `nminus` if it did.
@@ -218,19 +218,19 @@ class ScaleLRBySignChange(Transform):
  ub=50.0,
  alpha=1.0,
  use_grad=False,
- target: Target = "update",
  ):
  defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
- super().__init__(defaults, uses_grad=use_grad, target=target)
+ super().__init__(defaults, uses_grad=use_grad)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

- tensors = as_tensorlist(tensors)
- use_grad = settings[0]['use_grad']
- if use_grad: cur = as_tensorlist(grads)
+ tensors = TensorList(tensors)
+ if self._uses_grad:
+ assert grads is not None
+ cur = TensorList(grads)
  else: cur = tensors

  nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
@@ -252,7 +252,7 @@ class ScaleLRBySignChange(Transform):
  )
  return tensors

- class BacktrackOnSignChange(Transform):
+ class BacktrackOnSignChange(TensorTransform):
  """Negates or undoes update for parameters where where gradient or update sign changes.

  This is part of RProp update rule.
@@ -266,20 +266,21 @@ class BacktrackOnSignChange(Transform):
  Defaults to True.

  """
- def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
- defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
+ def __init__(self, use_grad = False, backtrack = True):
+ defaults = dict(use_grad=use_grad, backtrack=backtrack)
  super().__init__(defaults, uses_grad=use_grad)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

- tensors = as_tensorlist(tensors)
- use_grad = settings[0]['use_grad']
+ tensors = TensorList(tensors)
  backtrack = settings[0]['backtrack']

- if use_grad: cur = as_tensorlist(grads)
+ if self._uses_grad:
+ assert grads is not None
+ cur = TensorList(grads)
  else: cur = tensors

  tensors = backtrack_on_sign_change_(
@@ -292,54 +293,55 @@ class BacktrackOnSignChange(Transform):

  return tensors

- class SignConsistencyMask(Transform):
+ class SignConsistencyMask(TensorTransform):
  """
  Outputs a mask of sign consistency of current and previous inputs.

  The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.

- Examples:
-
- GD that skips update for weights where gradient sign changed compared to previous gradient.
+ ### Examples:

- .. code-block:: python
+ GD that skips update for weights where gradient sign changed compared to previous gradient.

- opt = tz.Modular(
- model.parameters(),
- tz.m.Mul(tz.m.SignConsistencyMask()),
- tz.m.LR(1e-2)
- )
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Mul(tz.m.SignConsistencyMask()),
+ tz.m.LR(1e-2)
+ )
+ ```

  """
- def __init__(self,target: Target = 'update'):
- super().__init__({}, uses_grad=False, target = target)
+ def __init__(self):
+ super().__init__()

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  prev = unpack_states(states, tensors, 'prev', cls=TensorList)
  mask = prev.mul_(tensors).gt_(0)
  prev.copy_(tensors)
  return mask


- class SignConsistencyLRs(Transform):
+ class SignConsistencyLRs(TensorTransform):
  """Outputs per-weight learning rates based on consecutive sign consistency.

- The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nplus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+ The learning rate for a weight is multiplied by ``nplus`` when two consecutive update signs are the same, otherwise it is multiplied by ``nplus``. The learning rates are bounded to be in ``(lb, ub)`` range.

- Examples:
+ ### Examples:

- GD scaled by consecutive gradient sign consistency
+ GD scaled by consecutive gradient sign consistency

- .. code-block:: python
+ ```python

- opt = tz.Modular(
- model.parameters(),
- tz.m.Mul(tz.m.SignConsistencyLRs()),
- tz.m.LR(1e-2)
- )
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Mul(tz.m.SignConsistencyLRs()),
+ tz.m.LR(1e-2)
+ )
+ ```

- """
+ """
  def __init__(
  self,
  nplus: float = 1.2,
@@ -347,17 +349,16 @@ class SignConsistencyLRs(Transform):
  lb: float | None = 1e-6,
  ub: float | None = 50,
  alpha: float = 1,
- target: Target = 'update'
  ):
  defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
- super().__init__(defaults, uses_grad=False, target = target)
+ super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

- target = as_tensorlist(tensors)
+ target = TensorList(tensors)
  nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
  prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
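The final hunk is truncated in this rendering, but the pattern across this file matches the others: the `target=` keyword was dropped from the constructors. A 0.3.15 call site that passed it explicitly would simply drop the argument; a sketch based on the SignConsistencyLRs docstring above (`tz.Modular`, `tz.m.Mul`, `tz.m.LR` come from that docstring):

```python
import torch
import torchzero as tz  # assumed top-level import name

model = torch.nn.Linear(10, 2)

# 0.3.15 allowed e.g. tz.m.SignConsistencyLRs(target='update');
# in 0.4.0 the keyword is removed and the constructor takes no target
opt = tz.Modular(
    model.parameters(),
    tz.m.Mul(tz.m.SignConsistencyLRs()),
    tz.m.LR(1e-2),
)
```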