torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/psgd/psgd_lra_whiten.py (new file)
@@ -0,0 +1,116 @@
+ # pylint:disable=not-callable
+ """all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+ import math
+ import warnings
+
+ import torch
+
+ from ....core import Chainable, TensorTransform
+ from ._psgd_utils import _initialize_lra_state_
+ from .psgd import lift2single, precond_grad_lra, update_precond_lra_whiten
+
+ # matches
+ class PSGDLRAWhiten(TensorTransform):
+ """Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+ Args:
+ rank (int, optional):
+ Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
+ init_scale (float | None, optional):
+ initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
+ lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+ betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+ damping (float, optional):
+ adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
+ grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
+ update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+ concat_params (bool, optional):
+ if True, treats all parameters as concatenated to a single vector.
+ If False, each parameter is preconditioned separately. Defaults to True.
+ inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+ ###Examples:
+
+ Pure PSGD LRA:
+ ```py
+ optimizer = tz.Optimizer(
+ model.parameters(),
+ tz.m.LRAWhiten(),
+ tz.m.LR(1e-3),
+ )
+ ```
+
+ Momentum into preconditioner (whitens momentum):
+ ```py
+ optimizer = tz.Optimizer(
+ model.parameters(),
+ tz.m.EMA(0.9),
+ tz.m.LRAWhiten(),
+ tz.m.LR(1e-3),
+ )
+ ```
+
+ Updating the preconditioner from gradients and applying it to momentum:
+ ```py
+ optimizer = tz.Optimizer(
+ model.parameters(),
+ tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
+ tz.m.LR(1e-3),
+ )
+ ```
+
+ """
+ def __init__(
+ self,
+ rank: int = 10,
+ init_scale: float | None = None,
+ lr_preconditioner=0.1,
+ betaL=0.9,
+ damping=1e-9,
+ grad_clip_max_amp=float("inf"),
+ update_probability=1.0,
+
+ concat_params: bool = True,
+ inner: Chainable | None = None,
+ ):
+ defaults = locals().copy()
+ del defaults["inner"], defaults["self"]
+ super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+ @torch.no_grad
+ def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+ _initialize_lra_state_(tensor, state, setting)
+
+ @torch.no_grad
+ def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+
+ g = tensor.ravel().unsqueeze(1) # column vector
+
+ UVd = state["UVd"]
+ if UVd[2] is None: # initialize d on the fly
+ UVd[2] = (torch.mean(g**4) + setting["damping"]**4)**(-1/8) * torch.ones_like(g)
+
+ if torch.rand([]) < setting["update_probability"]: # update preconditioner
+ update_precond_lra_whiten(
+ UVd=UVd,
+ Luvd=state["Luvd"],
+ g=g,
+ lr=setting["lr_preconditioner"],
+ betaL=setting["betaL"],
+ damping=setting["damping"],
+ )
+
+ @torch.no_grad
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+
+ g = tensor.ravel().unsqueeze(1)
+ pre_grad = precond_grad_lra(UVd=state["UVd"], g=g)
+
+ # norm clipping
+ grad_clip_max_amp = setting["grad_clip_max_amp"]
+ if grad_clip_max_amp < float("inf"): # clip preconditioned gradient
+ amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
+ if amp > grad_clip_max_amp:
+ pre_grad *= grad_clip_max_amp/amp
+
+ return pre_grad.view_as(tensor)
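For orientation, the constructor signature above suggests a configuration along the following lines. This is only a sketch inferred from this diff: the wiring mirrors the docstring examples, but whether the module is exposed as `tz.m.PSGDLRAWhiten` or under the `tz.m.LRAWhiten` alias used in the docstring is not established here, and `model` is a stand-in.

```py
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # stand-in model

# hypothetical configuration built from the __init__ arguments shown above:
# a rank-20 low-rank part, preconditioner updated on ~10% of steps,
# and the preconditioned update clipped to an RMS amplitude of 1.0
optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.PSGDLRAWhiten(rank=20, update_probability=0.1, grad_clip_max_amp=1.0),
    tz.m.LR(1e-3),
)
```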
torchzero/modules/adaptive/rmsprop.py
@@ -1,45 +1,11 @@
- from operator import itemgetter
  from typing import Literal

  import torch

- from ...core import Module, Target, Transform, Chainable, Var, apply_transform
+ from ...core import TensorTransform, Chainable
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_
-
-
- def rmsprop_(
- tensors_: TensorList,
- exp_avg_sq_: TensorList,
- smoothing: float | NumberList,
- eps: float | NumberList,
- debiased: bool,
- step: int,
- exp_avg_: TensorList | None = None,
- max_exp_avg_sq_: TensorList | None = None,
- pow: float = 2,
-
- # inner args
- inner: Module | None = None,
- params: list[torch.Tensor] | None = None,
- grads: list[torch.Tensor] | None = None,
- ):
- """returns `tensors_`"""
- if exp_avg_ is not None:
- sqrt_exp_avg_sq = sqrt_centered_ema_sq_(tensors=tensors_, exp_avg_=exp_avg_,
- exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
- beta=smoothing,debiased=debiased,step=step,pow=pow)
- else:
- sqrt_exp_avg_sq = sqrt_ema_sq_(tensors=tensors_,exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
- beta=smoothing,debiased=debiased,step=step,pow=pow)
-
- if inner is not None:
- assert params is not None
- tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
- return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
-
- class RMSprop(Transform):
+
+ class RMSprop(TensorTransform):
  """Divides graient by EMA of gradient squares.

  This implementation is identical to :code:`torch.optim.RMSprop`.
@@ -48,7 +14,7 @@ class RMSprop(Transform):
  smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
  eps (float, optional): epsilon for division. Defaults to 1e-8.
  centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
- debiased (bool, optional): applies Adam debiasing. Defaults to False.
+ debias (bool, optional): applies Adam debiasing. Defaults to False.
  amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
  pow (float, optional): power used in second momentum power and root. Defaults to 2.
  init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
@@ -60,44 +26,86 @@ class RMSprop(Transform):
  smoothing: float = 0.99,
  eps: float = 1e-8,
  centered: bool = False,
- debiased: bool = False,
+ debias: bool = False,
  amsgrad: bool = False,
- pow: float = 2,
  init: Literal["zeros", "update"] = "zeros",
+
  inner: Chainable | None = None,
+ exp_avg_sq_tfm: Chainable | None = None,
  ):
- defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
- super().__init__(defaults=defaults, uses_grad=False)
-
- if inner is not None:
- self.set_child('inner', inner)
-
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
- step = self.global_state['step'] = self.global_state.get('step', 0) + 1
- smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
- centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
-
- exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
- exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList) if centered else None
- max_exp_avg_sq = unpack_states(states, tensors, 'max_exp_avg_sq', cls=TensorList) if amsgrad else None
-
- if init == 'update' and step == 1:
- exp_avg_sq.set_([t**2 for t in tensors])
- if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])
-
- return rmsprop_(
- TensorList(tensors),
- exp_avg_sq_=exp_avg_sq,
- smoothing=smoothing,
- eps=eps,
- debiased=debiased,
- step=step,
- exp_avg_=exp_avg,
- max_exp_avg_sq_=max_exp_avg_sq,
- pow=pow,
-
- # inner args
- inner=self.children.get("inner", None),
- params=params,
- grads=grads,
- )
+ defaults = locals().copy()
+ del defaults['self'], defaults["inner"], defaults["exp_avg_sq_tfm"]
+ super().__init__(defaults, inner=inner)
+
+ self.set_child('exp_avg_sq', exp_avg_sq_tfm)
+
+ @torch.no_grad
+ def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+ if setting["init"] == "zeros":
+ state["exp_avg_sq"] = torch.zeros_like(tensor)
+ if setting["centered"]: state["exp_avg"] = torch.zeros_like(tensor)
+ if setting["amsgrad"]: state["amsgrad"] = torch.zeros_like(tensor)
+
+ else:
+ state["exp_avg_sq"] = tensor ** 2
+ if setting["centered"]: state["exp_avg"] = tensor.clone()
+ if setting["amsgrad"]: state["amsgrad"] = tensor ** 2
+
+ @torch.no_grad
+ def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+ self.increment_counter("step", start = 0)
+ fs = settings[0]
+
+ exp_avg_sq = unpack_states(states, tensors, "exp_avg_sq", cls=TensorList)
+
+ # update exponential average
+ smoothing = NumberList(s["smoothing"] for s in settings)
+ exp_avg_sq.mul_(smoothing).addcmul_(tensors, tensors, value=1-smoothing)
+
+ # update mean estimate if centered
+ if fs["centered"]:
+ exp_avg = unpack_states(states, tensors, "exp_avg", cls=TensorList)
+ exp_avg.lerp_(tensors, 1-smoothing)
+
+ # amsgrad
+ if fs["amsgrad"]:
+ exp_avg_sq_max = unpack_states(states, tensors, "exp_avg_sq_max", cls=TensorList)
+ exp_avg_sq_max.maximum_(exp_avg_sq)
+
+ @torch.no_grad
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+ tensors = TensorList(tensors)
+ step = self.global_state["step"] # 0 on 1st step
+ eps = NumberList(s["eps"] for s in settings)
+ fs = settings[0]
+
+ if fs["amsgrad"]: key = "max_exp_avg_sq"
+ else: key = "exp_avg_sq"
+ exp_avg_sq = TensorList(s[key] for s in states)
+
+ # load mean estimate if centered
+ exp_avg = None
+ if fs['centered']:
+ exp_avg = TensorList(s["exp_avg"] for s in states)
+
+ # debias exp_avg_sq and exp_avg
+ if fs["debias"]:
+ smoothing = NumberList(s["smoothing"] for s in settings)
+ bias_correction = 1 - (smoothing ** (step + 1))
+ exp_avg_sq = exp_avg_sq / bias_correction
+
+ if fs['centered']:
+ assert exp_avg is not None
+ exp_avg = exp_avg / bias_correction
+
+ # apply transform to potentially debiased exp_avg_sq
+ exp_avg_sq = TensorList(self.inner_step_tensors(
+ "exp_avg_sq", exp_avg_sq, params=params, grads=grads, loss=loss, clone=True, must_exist=False
+ ))
+
+ # center
+ if fs["centered"]:
+ assert exp_avg is not None
+ exp_avg_sq = exp_avg_sq.addcmul(exp_avg, exp_avg, value=-1)
+
+ return tensors.div_(exp_avg_sq.sqrt().add_(eps))
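To make the split between `multi_tensor_update` and `multi_tensor_apply` above concrete, here is a plain-PyTorch sketch of the same arithmetic for a single tensor with `debias=True` and `centered=False` (an illustration of the math in this diff, not the library code):

```py
import torch

torch.manual_seed(0)
g = torch.randn(5)            # current update
exp_avg_sq = torch.zeros(5)   # "zeros" init of the EMA of squares
smoothing, eps, step = 0.99, 1e-8, 0

# update phase: EMA of squared values
exp_avg_sq.mul_(smoothing).addcmul_(g, g, value=1 - smoothing)

# apply phase: debias by 1 - smoothing**(step + 1), then divide
bias_correction = 1 - smoothing ** (step + 1)   # 0.01 on the first step
denom = (exp_avg_sq / bias_correction).sqrt() + eps
out = g / denom   # without debiasing, the first update would be ~10x larger
```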
torchzero/modules/adaptive/rprop.py
@@ -1,8 +1,8 @@

  import torch

- from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+ from ...core import TensorTransform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


  def _bool_ones_like(x):
@@ -126,7 +126,7 @@ def rprop_(



- class Rprop(Transform):
+ class Rprop(TensorTransform):
  """
  Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
  or `nminus` if it did. Then the update is applied with the sign of the current gradient.
@@ -165,7 +165,7 @@ class Rprop(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

@@ -178,7 +178,7 @@ class Rprop(Transform):
  )

  tensors = rprop_(
- tensors_ = as_tensorlist(tensors),
+ tensors_ = TensorList(tensors),
  prev_ = prev,
  allowed_ = allowed,
  magnitudes_ = magnitudes,
@@ -194,7 +194,7 @@ class Rprop(Transform):
  return tensors


- class ScaleLRBySignChange(Transform):
+ class ScaleLRBySignChange(TensorTransform):
  """
  learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
  or `nminus` if it did.
@@ -218,19 +218,19 @@ class ScaleLRBySignChange(Transform):
  ub=50.0,
  alpha=1.0,
  use_grad=False,
- target: Target = "update",
  ):
  defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
- super().__init__(defaults, uses_grad=use_grad, target=target)
+ super().__init__(defaults, uses_grad=use_grad)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

- tensors = as_tensorlist(tensors)
- use_grad = settings[0]['use_grad']
- if use_grad: cur = as_tensorlist(grads)
+ tensors = TensorList(tensors)
+ if self._uses_grad:
+ assert grads is not None
+ cur = TensorList(grads)
  else: cur = tensors

  nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
@@ -252,7 +252,7 @@ class ScaleLRBySignChange(Transform):
  )
  return tensors

- class BacktrackOnSignChange(Transform):
+ class BacktrackOnSignChange(TensorTransform):
  """Negates or undoes update for parameters where where gradient or update sign changes.

  This is part of RProp update rule.
@@ -266,20 +266,21 @@ class BacktrackOnSignChange(Transform):
  Defaults to True.

  """
- def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
- defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
+ def __init__(self, use_grad = False, backtrack = True):
+ defaults = dict(use_grad=use_grad, backtrack=backtrack)
  super().__init__(defaults, uses_grad=use_grad)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

- tensors = as_tensorlist(tensors)
- use_grad = settings[0]['use_grad']
+ tensors = TensorList(tensors)
  backtrack = settings[0]['backtrack']

- if use_grad: cur = as_tensorlist(grads)
+ if self._uses_grad:
+ assert grads is not None
+ cur = TensorList(grads)
  else: cur = tensors

  tensors = backtrack_on_sign_change_(
@@ -292,54 +293,55 @@ class BacktrackOnSignChange(Transform):

  return tensors

- class SignConsistencyMask(Transform):
+ class SignConsistencyMask(TensorTransform):
  """
  Outputs a mask of sign consistency of current and previous inputs.

  The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.

- Examples:
-
- GD that skips update for weights where gradient sign changed compared to previous gradient.
+ ### Examples:

- .. code-block:: python
+ GD that skips update for weights where gradient sign changed compared to previous gradient.

- opt = tz.Modular(
- model.parameters(),
- tz.m.Mul(tz.m.SignConsistencyMask()),
- tz.m.LR(1e-2)
- )
+ ```python
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.Mul(tz.m.SignConsistencyMask()),
+ tz.m.LR(1e-2)
+ )
+ ```

  """
- def __init__(self,target: Target = 'update'):
- super().__init__({}, uses_grad=False, target = target)
+ def __init__(self):
+ super().__init__()

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  prev = unpack_states(states, tensors, 'prev', cls=TensorList)
  mask = prev.mul_(tensors).gt_(0)
  prev.copy_(tensors)
  return mask


- class SignConsistencyLRs(Transform):
+ class SignConsistencyLRs(TensorTransform):
  """Outputs per-weight learning rates based on consecutive sign consistency.

- The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nplus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+ The learning rate for a weight is multiplied by ``nplus`` when two consecutive update signs are the same, otherwise it is multiplied by ``nplus``. The learning rates are bounded to be in ``(lb, ub)`` range.

- Examples:
+ ### Examples:

- GD scaled by consecutive gradient sign consistency
+ GD scaled by consecutive gradient sign consistency

- .. code-block:: python
+ ```python

- opt = tz.Modular(
- model.parameters(),
- tz.m.Mul(tz.m.SignConsistencyLRs()),
- tz.m.LR(1e-2)
- )
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.Mul(tz.m.SignConsistencyLRs()),
+ tz.m.LR(1e-2)
+ )
+ ```

- """
+ """
  def __init__(
  self,
  nplus: float = 1.2,
@@ -347,17 +349,16 @@ class SignConsistencyLRs(Transform):
  lb: float | None = 1e-6,
  ub: float | None = 50,
  alpha: float = 1,
- target: Target = 'update'
  ):
  defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
- super().__init__(defaults, uses_grad=False, target = target)
+ super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

- target = as_tensorlist(tensors)
+ target = TensorList(tensors)
  nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
  prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

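The `SignConsistencyMask.multi_tensor_apply` shown above boils down to an elementwise product test; here is a standalone sketch of that arithmetic (an illustration, not the library code):

```py
import torch

prev = torch.tensor([ 1.0, -2.0, 3.0, -4.0])  # previous input
cur  = torch.tensor([ 0.5,  1.0, 3.0,  4.0])  # current input

# 1 where the sign is unchanged, 0 where it flipped (or either value is zero)
mask = (prev * cur).gt(0)
print(mask)  # tensor([ True, False,  True, False])
# multiplying the update by this mask skips weights whose sign just flipped,
# which is what the tz.m.Mul(tz.m.SignConsistencyMask()) docstring example composes
```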
torchzero/modules/adaptive/sam.py
@@ -1,10 +1,10 @@
  from contextlib import nullcontext
  import torch
- from ...utils import TensorList, NumberList
- from ...core import Module
+ from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
+ from ...core import Transform


- class SAM(Module):
+ class SAM(Transform):
  """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

  SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
@@ -22,50 +22,51 @@ class SAM(Module):
  p (float, optional): norm of the SAM objective. Defaults to 2.
  asam (bool, optional):
  enables ASAM variant which makes perturbation relative to weight magnitudes.
- ASAM requires a much larger :code:`rho`, like 0.5 or 1.
- The :code:`tz.m.ASAM` class is idential to setting this argument to True, but
- it has larger :code:`rho` by default.
+ ASAM requires a much larger ``rho``, like 0.5 or 1.
+ The ``tz.m.ASAM`` class is idential to setting this argument to True, but
+ it has larger ``rho`` by default.

- Examples:
- SAM-SGD:
+ ### Examples:

- .. code-block:: python
+ SAM-SGD:

- opt = tz.Modular(
- model.parameters(),
- tz.m.SAM(),
- tz.m.LR(1e-2)
- )
+ ```py
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.SAM(),
+ tz.m.LR(1e-2)
+ )
+ ```

- SAM-Adam:
+ SAM-Adam:

- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.SAM(),
- tz.m.Adam(),
- tz.m.LR(1e-2)
- )
+ ```
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.SAM(),
+ tz.m.Adam(),
+ tz.m.LR(1e-2)
+ )
+ ```

  References:
- Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+ [Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.](https://arxiv.org/abs/2010.01412#page=3.16)
  """
  def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
  defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
  super().__init__(defaults)

  @torch.no_grad
- def step(self, var):
+ def update_states(self, objective, states, settings):

- params = var.params
- closure = var.closure
- zero_grad = var.zero_grad
+ params = objective.params
+ closure = objective.closure
+ zero_grad = objective.zero_grad
  if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
- p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
- s = self.defaults
- eps = s['eps']
- asam = s['asam']
+ p, rho = unpack_dicts(settings, 'p', 'rho', cls=NumberList)
+ fs = settings[0]
+ eps = fs['eps']
+ asam = fs['asam']

  # 1/p + 1/q = 1
  # okay, authors of SAM paper, I will manually solve your equation
@@ -123,8 +124,7 @@ class SAM(Module):

  return sam_loss

- var.closure = sam_closure
- return var
+ objective.closure = sam_closure

  # different class because defaults for SAM are bad for ASAM
  class ASAM(SAM):
@@ -136,7 +136,7 @@ class ASAM(SAM):
  This implementation modifies the closure to return loss and calculate gradients
  of the SAM objective. All modules after this will use the modified objective.

- .. note::
+ Note:
  This module requires a closure passed to the optimizer step,
  as it needs to re-evaluate the loss and gradients at two points on each step.

@@ -144,20 +144,30 @@ class ASAM(SAM):
  rho (float, optional): Neighborhood size. Defaults to 0.05.
  p (float, optional): norm of the SAM objective. Defaults to 2.

- Examples:
- ASAM-Adam:
+ ### Examples:
+
+ ASAM-SGD:

- .. code-block:: python
+ ```py
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.ASAM(),
+ tz.m.LR(1e-2)
+ )
+ ```

- opt = tz.Modular(
- model.parameters(),
- tz.m.ASAM(),
- tz.m.Adam(),
- tz.m.LR(1e-2)
- )
+ ASAM-Adam:

+ ```
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.ASAM(),
+ tz.m.Adam(),
+ tz.m.LR(1e-2)
+ )
+ ```
  References:
- Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+ [Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.](https://arxiv.org/abs/2102.11600)
  """
  def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
  super().__init__(rho=rho, p=p, eps=eps, asam=True)
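The SAM hunks above keep the interesting part, the `sam_closure` that re-evaluates gradients at a perturbed point, out of view. For reference, the standard SAM recipe from the cited Foret et al. paper for `p=2` looks roughly like the sketch below; this is a conceptual illustration only, not torchzero's actual closure, which additionally handles gradient zeroing, general `p` norms, and the ASAM scaling.

```py
import torch

def make_sam_closure(params, closure, rho=0.05, eps=1e-10):
    """closure() is assumed to zero gradients, backprop, and return the loss."""
    def sam_closure():
        loss = closure()                                   # gradients at the current point
        grads = [p.grad.detach().clone() for p in params]
        grad_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads)).item()
        scale = rho / (grad_norm + eps)
        with torch.no_grad():
            for p, g in zip(params, grads):
                p.add_(g, alpha=scale)                     # climb to the worst-case neighbor
        sam_loss = closure()                               # gradients of the SAM objective
        with torch.no_grad():
            for p, g in zip(params, grads):
                p.sub_(g, alpha=scale)                     # restore the original parameters
        return sam_loss
    return sam_closure
```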