torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/momentum/matrix_momentum.py

@@ -2,123 +2,192 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, apply
+from ...core import Module, apply_transform, Chainable
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 
 class MatrixMomentum(Module):
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives, also for some reason it has very really good generalization on elastic net logistic regression.
+
+    .. note::
+        :code:`mu` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
+
+    .. note::
+        I have devised an adaptive version of this - :code:`tz.m.AdaptiveMatrixMomentum`, and it works well
+        without having to tune :code:`mu`.
+
+    .. note::
+        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
+        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-    Evaluates hessian vector product on each step (via finite difference or autograd).
 
-    `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
-
-    Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-    """
-    def __init__(self, mu=0.1, beta:float=1, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
-        defaults = dict(mu=mu, beta=beta, hvp_mode=hvp_mode, h=h)
+    def __init__(
+        self,
+        mu=0.1,
+        beta: float = 1,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
-    @torch.no_grad
-    def step(self, vars):
-        assert vars.closure is not None
-        prev_update = self.get_state('prev_update', params=vars.params, cls=TensorList)
-        hvp_mode = self.settings[vars.params[0]]['hvp_mode']
-        h = self.settings[vars.params[0]]['h']
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_update')
 
-        mu,beta = self.get_settings('mu','beta', params=vars.params, cls=NumberList)
-
-        if hvp_mode == 'autograd':
-            with torch.enable_grad():
-                grad = vars.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
+    @torch.no_grad
+    def update(self, var):
+        assert var.closure is not None
+        prev_update = self.get_state(var.params, 'prev_update')
+        hvp_method = self.settings[var.params[0]]['hvp_method']
+        h = self.settings[var.params[0]]['h']
 
-        elif hvp_mode == 'forward':
-            vars.get_grad()
-            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
-        elif hvp_mode == 'central':
-            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        if 'hvp_tfm' in self.children:
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
-        else:
-            raise ValueError(hvp_mode)
+        self.store(var.params, "Hvp", Hvp)
 
-        if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
 
-        update = TensorList(vars.get_update())
+    @torch.no_grad
+    def apply(self, var):
+        update = TensorList(var.get_update())
+        Hvp, prev_update = self.get_state(var.params, 'Hvp', 'prev_update', cls=TensorList)
+        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
 
-        hvp_ = as_tensorlist(hvp_)
-        update.add_(prev_update - hvp_*mu)
+        update.add_(prev_update - Hvp*mu)
         prev_update.set_(update * beta)
-        vars.update = update
-        return vars
+        var.update = update
+        return var
 
 
 class AdaptiveMatrixMomentum(Module):
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives, also for some reason it has very good generalization on elastic net logistic regression.
+
+    .. note::
+        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+
+    Args:
+        mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
+        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-    Mu here is estimated as ||s_k||/||y_k||.
-    """
-    def __init__(self, mu_mul:float=1, beta:float=1, eps=1e-4, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
-        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_mode=hvp_mode, h=h, eps=eps)
+
+    def __init__(
+        self,
+        mu_mul: float = 1,
+        beta: float = 1,
+        eps=1e-4,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_params', 'prev_grad')
+
     @torch.no_grad
-    def step(self, vars):
-        assert vars.closure is not None
-        prev_update, prev_params, prev_grad = self.get_state('prev_update', 'prev_params', 'prev_grad', params=vars.params, cls=TensorList)
+    def update(self, var):
+        assert var.closure is not None
+        prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
 
-        settings = self.settings[vars.params[0]]
-        hvp_mode = settings['hvp_mode']
+        settings = self.settings[var.params[0]]
+        hvp_method = settings['hvp_method']
         h = settings['h']
         eps = settings['eps']
 
-        mu_mul, beta = self.get_settings('mu_mul','beta', params=vars.params, cls=NumberList)
-
-        if hvp_mode == 'autograd':
-            with torch.enable_grad():
-                grad = vars.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_mode == 'forward':
-            vars.get_grad()
-            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        mu_mul = NumberList(self.settings[p]['mu_mul'] for p in var.params)
 
-        elif hvp_mode == 'central':
-            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
-
-        else:
-            raise ValueError(hvp_mode)
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
         if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
         # adaptive part
-        update = TensorList(vars.get_update())
-
-        s_k = vars.params - prev_params
-        prev_params.copy_(vars.params)
+        s_k = var.params - prev_params
+        prev_params.copy_(var.params)
 
-        assert vars.grad is not None
-        y_k = vars.grad - prev_grad
-        prev_grad.copy_(vars.grad)
+        if hvp_method != 'central': assert var.grad is not None
+        grad = var.get_grad()
+        y_k = grad - prev_grad
+        prev_grad.copy_(grad)
 
         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
 
-        # matrix momentum uppdate
-        hvp_ = as_tensorlist(hvp_)
-        update.add_(prev_update - hvp_*ada_mu)
+        self.store(var.params, ['Hvp', 'ada_mu'], [Hvp, ada_mu])
+
+    @torch.no_grad
+    def apply(self, var):
+        Hvp, ada_mu = self.get_state(var.params, 'Hvp', 'ada_mu')
+        Hvp = as_tensorlist(Hvp)
+        beta = NumberList(self.settings[p]['beta'] for p in var.params)
+        update = TensorList(var.get_update())
+        prev_update = TensorList(self.state[p]['prev_update'] for p in var.params)
+
+        update.add_(prev_update - Hvp*ada_mu)
         prev_update.set_(update * beta)
-        vars.update = update
-        return vars
+        var.update = update
+        return var
 
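In 0.3.11 both classes split the old step() into an update()/apply() pair, rename hvp_mode to hvp_method (now defaulting to "autograd"), and route HVP computation through the shared self.Hvp helper; the new apply() still forms the update as u_t + prev_update - mu * (H @ prev_update). The docstrings above require a closure that accepts a backward argument. A minimal usage sketch follows; the tz.Modular entry point, tz.m.LR module, and the toy model/data are assumptions drawn from the torchzero README, not from this diff (only tz.m.AdaptiveMatrixMomentum is named above):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    # constructor form assumed from the torchzero README
    opt = tz.Modular(model.parameters(), tz.m.AdaptiveMatrixMomentum(), tz.m.LR(1e-2))

    def closure(backward=True):
        # re-evaluated by the optimizer; `backward` lets it request loss-only evaluations
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            model.zero_grad()
            loss.backward()
        return loss

    for _ in range(100):
        opt.step(closure)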
torchzero/modules/momentum/momentum.py

@@ -3,11 +3,22 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from .ema import EMA
 
 
 class HeavyBall(EMA):
+    """Polyak's momentum (heavy-ball method).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
         super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
 
@@ -30,14 +41,24 @@ def nag_(
 
 
 class NAG(Transform):
+    """Nesterov accelerated gradient method (nesterov momentum).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
         defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        velocity = self.get_state('velocity', params=params, cls=TensorList)
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']
 
-        momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+        momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
         return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
+
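The new docstrings describe HeavyBall and NAG in terms of momentum, dampening, and a lerp flag, but the bodies of the underlying ema_/nag_ helpers are outside these hunks. For orientation, here is a rough standalone sketch of the textbook update rules those parameters correspond to; it is an illustration under that assumption, not torchzero's exact nag_ implementation:

    import torch

    def heavy_ball_(grad, velocity, momentum=0.9, dampening=0.0, lerp=False):
        # Polyak heavy-ball buffer: v <- momentum*v + (1 - dampening)*g,
        # or an exponential moving average of gradients when lerp=True
        if lerp:
            velocity.lerp_(grad, 1 - momentum)
        else:
            velocity.mul_(momentum).add_(grad, alpha=1 - dampening)
        return velocity

    def nesterov_(grad, velocity, momentum=0.9, dampening=0.0):
        # textbook Nesterov momentum: update the buffer, then step along g + momentum*v
        velocity.mul_(momentum).add_(grad, alpha=1 - dampening)
        return grad.add(velocity, alpha=momentum)

    g, v = torch.randn(3), torch.zeros(3)
    direction = nesterov_(g, v)  # what a NAG-like module would hand to the next module in the chain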
torchzero/modules/ops/__init__.py

@@ -7,7 +7,7 @@ from .accumulate import (
 )
 from .binary import (
     Add,
-    BinaryOperation,
+    BinaryOperationBase,
     Clip,
     CopyMagnitude,
     CopySign,
@@ -27,37 +27,12 @@ from .binary import (
     Sub,
     Threshold,
 )
-from .debug import PrintShape, PrintUpdate
-from .misc import (
-    DivByLoss,
-    Dropout,
-    FillLoss,
-    GradientAccumulation,
-    GradSign,
-    GraftGradToUpdate,
-    GraftToGrad,
-    GraftToParams,
-    LastAbsoluteRatio,
-    LastDifference,
-    LastGradDifference,
-    LastProduct,
-    LastRatio,
-    MulByLoss,
-    Multistep,
-    NegateOnLossIncrease,
-    NoiseSign,
-    Previous,
-    Relative,
-    Sequential,
-    UpdateSign,
-    WeightDropout,
-)
 from .multi import (
     ClipModules,
     DivModules,
     GraftModules,
     LerpModules,
-    MultiOperation,
+    MultiOperationBase,
     PowModules,
     SubModules,
 )
@@ -66,13 +41,11 @@ from .reduce import (
     Mean,
     MinimumModules,
     Prod,
-    ReduceOperation,
+    ReduceOperationBase,
     Sum,
     WeightedMean,
     WeightedSum,
 )
-from .split import Split
-from .switch import Alternate, Switch
 from .unary import (
     Abs,
     CustomUnaryOperation,
@@ -97,7 +70,6 @@ from .utility import (
     Randn,
     RandomSample,
     Uniform,
-    Update,
     UpdateToNone,
     Zeros,
 )
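These hunks only rename the operation base classes and drop re-exports whose sources moved into the new misc package (see the moved files in the list above). A hedged before/after import sketch; the renamed ops exports are confirmed by the hunks, while the torchzero.modules.misc path is an assumption based on the misc/*.py entries in the file list:

    # 0.3.9
    from torchzero.modules.ops import BinaryOperation, MultiOperation, ReduceOperation, Split

    # 0.3.11: base classes carry a *Base suffix in torchzero.modules.ops (shown above)
    from torchzero.modules.ops import BinaryOperationBase, MultiOperationBase, ReduceOperationBase
    # Split, Switch, GradientAccumulation and the other former ops.misc exports appear to
    # live under the new misc package (path assumed from the file list, not this hunk)
    from torchzero.modules.misc import Split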
torchzero/modules/ops/accumulate.py

@@ -1,65 +1,91 @@
-from collections import deque
-from operator import itemgetter
-from typing import Literal
-
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList, NumberList
+from ...utils import TensorList, unpack_states
 
 class AccumulateSum(Transform):
+    """Accumulates sum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        sum = self.get_state('sum', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return sum.add_(tensors).lazy_mul(1-decay, clone=True)
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return sum.add_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMean(Transform):
+    """Accumulates mean of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-        mean = self.get_state('mean', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return mean.add_(tensors).lazy_mul(1-decay, clone=True).div_(step)
+        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)
 
 class AccumulateProduct(Transform):
+    """Accumulates product of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prod = self.get_state('prod', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return prod.mul_(tensors).lazy_mul(1-decay, clone=True)
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return prod.mul_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMaximum(Transform):
+    """Accumulates maximum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        maximum = self.get_state('maximum', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return maximum.maximum_(tensors).lazy_mul(1-decay, clone=True)
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMinimum(Transform):
+    """Accumulates minimum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        minimum = self.get_state('minimum', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return minimum.minimum_(tensors).lazy_mul(1-decay, clone=True)
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
 
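The rewritten apply_tensors methods keep the same accumulate-then-decay rule as the old transform methods; they only switch to the new per-parameter states/settings arguments and the unpack_states helper. A tiny pure-torch sketch of the rule for AccumulateSum, as an illustration rather than the library code:

    import torch

    def accumulate_sum(update, acc, decay=0.0):
        # mirrors `sum.add_(tensors).lazy_mul(1 - decay, clone=True)` above:
        # the buffer keeps the raw running sum, the returned update is scaled by (1 - decay)
        acc.add_(update)
        return acc * (1 - decay)

    acc = torch.zeros(3)
    for _ in range(3):
        out = accumulate_sum(torch.ones(3), acc, decay=0.1)
    # acc is now 3.0 everywhere, out is 2.7 everywhere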