torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/ops/debug.py
@@ -0,0 +1,25 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import Module
+ from ...utils.tensorlist import Distributions
+
+ class PrintUpdate(Module):
+     def __init__(self, text = 'update = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, vars):
+         self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{vars.update}')
+         return vars
+
+ class PrintShape(Module):
+     def __init__(self, text = 'shapes = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, vars):
+         shapes = [u.shape for u in vars.update] if vars.update is not None else None
+         self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{shapes}')
+         return vars
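
The two debug modules above show the Module pattern used throughout the 0.3.x codebase: constructor options go into a `defaults` dict, are read back per-parameter from `self.settings`, and `step` receives and returns `vars`. As a hedged sketch of that same pattern (the `PrintGradNorm` name is illustrative and not part of this release; `vars.get_grad()` is used the same way in the misc.py hunk below), a module printing the global gradient norm could look like this if it lived next to the classes above:

import torch

from ...core import Module


class PrintGradNorm(Module):
    """Hypothetical example, not in the wheel: prints the global gradient norm each step."""
    def __init__(self, text='grad norm = ', print_fn=print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, vars):
        settings = self.settings[vars.params[0]]
        grad = vars.get_grad()  # list of per-parameter gradient tensors
        norm = torch.linalg.vector_norm(torch.cat([g.reshape(-1) for g in grad]))
        settings["print_fn"](f'{settings["text"]}{norm.item():.4e}')
        return vars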
torchzero/modules/ops/misc.py
@@ -0,0 +1,419 @@
+ from collections import deque
+ from collections.abc import Iterable
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, TensorwiseTransform, Target, Transform, Vars
+ from ...utils import Distributions, NumberList, TensorList
+
+
+ class Previous(TensorwiseTransform):
+     """Maintains a history of updates and returns the update from n steps back; for example, with n=1 it returns the previous update."""
+     def __init__(self, n=1, target: Target = 'update'):
+         defaults = dict(n=n)
+         super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+     @torch.no_grad
+     def transform(self, tensor, param, grad, vars):
+         n = self.settings[param]['n']
+         state = self.state[param]
+
+         if 'history' not in state:
+             state['history'] = deque(maxlen=n+1)
+
+         state['history'].append(tensor)
+
+         return state['history'][0]
+
+
+ class LastDifference(Transform):
+     """Difference between the past two updates."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         prev_target = self.get_state('prev_target', params=params) # initialized to 0
+         difference = torch._foreach_sub(tensors, prev_target)
+         for p, c in zip(prev_target, tensors): p.set_(c)
+         return difference
+
+ class LastGradDifference(Module):
+     """Difference between the past two gradients."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, vars):
+         grad = vars.get_grad()
+         prev_grad = self.get_state('prev_grad', params=vars.params) # initialized to 0
+         difference = torch._foreach_sub(grad, prev_grad)
+         for p, c in zip(prev_grad, grad): p.set_(c)
+         vars.update = list(difference)
+         return vars
+
+
+ class LastProduct(Transform):
+     """Product of the past two updates."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         prev_target = self.get_state('prev_target', params=params, init=torch.ones_like) # initialized to 1 for prod
+         prod = torch._foreach_mul(tensors, prev_target)
+         for p, c in zip(prev_target, tensors): p.set_(c)
+         return prod
+
+ class LastRatio(Transform):
+     """Ratio between the past two updates."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+         defaults = dict(numerator=numerator)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         prev_target = self.get_state('prev_target', params=params, init=torch.ones_like) # initialized to ones
+         numerator = self.settings[params[0]]['numerator']
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
+         else: ratio = torch._foreach_div(prev_target, tensors)
+         for p, c in zip(prev_target, tensors): p.set_(c)
+         return ratio
+
+ class LastAbsoluteRatio(Transform):
+     """Ratio between the absolute values of the past two updates."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps: float = 1e-8, target: Target = 'update'):
+         defaults = dict(numerator=numerator, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         prev_target = self.get_state('prev_target', params=params, init=torch.ones_like) # initialized to ones
+         numerator = self.settings[params[0]]['numerator']
+         eps = self.get_settings('eps', params=params, cls=NumberList)
+
+         torch._foreach_abs_(tensors)
+         torch._foreach_clamp_min_(prev_target, eps)
+
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
+         else: ratio = torch._foreach_div(prev_target, tensors)
+         for p, c in zip(prev_target, tensors): p.set_(c)
+         return ratio
+
+ class GradSign(Transform):
+     """Copies the gradient's sign onto the update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         assert grads is not None
+         return [t.copysign_(g) for t, g in zip(tensors, grads)]
+
+ class UpdateSign(Transform):
+     """Uses per-weight magnitudes from the gradient while taking the sign from the update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         assert grads is not None
+         return [g.copysign(t) for t, g in zip(tensors, grads)] # no in-place
+
+ class GraftToGrad(Transform):
+     """Uses the gradient's norm with the update's direction."""
+     def __init__(self, tensorwise: bool = False, ord: float = 2, eps: float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise', 'ord', 'eps')(self.settings[params[0]])
+         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class GraftGradToUpdate(Transform):
+     """Uses the update's norm with the gradient's direction."""
+     def __init__(self, tensorwise: bool = False, ord: float = 2, eps: float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise', 'ord', 'eps')(self.settings[params[0]])
+         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+ class GraftToParams(Transform):
+     """Sets the update norm to the parameter norm, which is not allowed to go below eps."""
+     def __init__(self, tensorwise: bool = False, ord: float = 2, eps: float = 1e-4, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         tensorwise, ord, eps = itemgetter('tensorwise', 'ord', 'eps')(self.settings[params[0]])
+         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class Relative(Transform):
+     """Multiplies the update by the absolute parameter values to make it relative to their magnitude; min_value is the minimum multiplier, which avoids getting stuck at 0."""
+     def __init__(self, min_value: float = 1e-4, target: Target = 'update'):
+         defaults = dict(min_value=min_value)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         mul = TensorList(params).abs().clamp_(self.get_settings('min_value', params=params))
+         torch._foreach_mul_(tensors, mul)
+         return tensors
+
+ class FillLoss(Module):
+     """Sets the update to tensors filled with the loss value times alpha."""
+     def __init__(self, alpha: float = 1, backward: bool = True):
+         defaults = dict(alpha=alpha, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, vars):
+         alpha = self.get_settings('alpha', params=vars.params)
+         loss = vars.get_loss(backward=self.settings[vars.params[0]]['backward'])
+         vars.update = [torch.full_like(p, loss*a) for p, a in zip(vars.params, alpha)]
+         return vars
+
+ class MulByLoss(Transform):
+     """Multiplies the update by the loss times alpha."""
+     def __init__(self, alpha: float = 1, min_value: float = 1e-8, backward: bool = True, target: Target = 'update'):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars): # vars used for loss
+         alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
+         loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
+         mul = [max(loss*a, mv) for a, mv in zip(alpha, min_value)]
+         torch._foreach_mul_(tensors, mul)
+         return tensors
+
+ class DivByLoss(Transform):
+     """Divides the update by the loss times alpha."""
+     def __init__(self, alpha: float = 1, min_value: float = 1e-8, backward: bool = True, target: Target = 'update'):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars): # vars used for loss
+         alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
+         loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
+         mul = [max(loss*a, mv) for a, mv in zip(alpha, min_value)]
+         torch._foreach_div_(tensors, mul)
+         return tensors
+
+
+
+ def _sequential_step(self: Module, vars: Vars, sequential: bool):
+     params = vars.params
+     steps = self.settings[params[0]]['steps']
+
+     if sequential: modules = self.get_children_sequence()
+     else: modules = [self.children['module']] * steps
+
+     if vars.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+     # store original params unless this is last module and can update params directly
+     params_before_steps = None if (vars.is_last and vars.last_module_lrs is None) else [p.clone() for p in params]
+
+     # first step - pass vars as usual
+     vars = modules[0].step(vars)
+     new_vars = vars
+
+     # subsequent steps - update parameters and create new vars
+     if len(modules) > 1:
+         for m in modules[1:]:
+
+             # update params
+             if (not new_vars.skip_update):
+                 if new_vars.last_module_lrs is not None:
+                     torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
+
+                 torch._foreach_sub_(params, new_vars.get_update())
+
+             # create new vars since we are at a new point, that means grad, update and loss will be None
+             new_vars = Vars(params=new_vars.params, closure=new_vars.closure,
+                             model=new_vars.model, current_step=new_vars.current_step + 1)
+
+             # step
+             new_vars = m.step(new_vars)
+
+     # final parameter update
+     if (not new_vars.skip_update):
+         if new_vars.last_module_lrs is not None:
+             torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
+
+         torch._foreach_sub_(params, new_vars.get_update())
+
+     # if last module, update is applied so return new vars
+     if params_before_steps is None:
+         new_vars.stop = True
+         new_vars.skip_update = True
+         return new_vars
+
+     # otherwise use parameter difference as update
+     vars.update = list(torch._foreach_sub(params_before_steps, params))
+     for p, bef in zip(params, params_before_steps):
+         p.set_(bef) # pyright:ignore[reportArgumentType]
+     return vars
+
+ class Multistep(Module):
+     """Steps with the wrapped module `steps` times per optimizer step."""
+     def __init__(self, module: Chainable, steps: int):
+         defaults = dict(steps=steps)
+         super().__init__(defaults)
+         self.set_child('module', module)
+
+     @torch.no_grad
+     def step(self, vars):
+         return _sequential_step(self, vars, sequential=False)
+
+ class Sequential(Module):
+     """Steps with each module in the sequence one after another, updating parameters between them."""
+     def __init__(self, modules: Iterable[Chainable], steps: int):
+         defaults = dict(steps=steps)
+         super().__init__(defaults)
+         self.set_children_sequence(modules)
+
+     @torch.no_grad
+     def step(self, vars):
+         return _sequential_step(self, vars, sequential=True)
+
+
+ class GradientAccumulation(Module):
+     """Accumulates updates over `n` steps, then steps with the wrapped modules using the accumulated (optionally averaged) update."""
+     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+         defaults = dict(n=n, mean=mean, stop=stop)
+         super().__init__(defaults)
+         self.set_child('modules', modules)
+
+
+     @torch.no_grad
+     def step(self, vars):
+         accumulator = self.get_state('accumulator', params=vars.params)
+         settings = self.settings[vars.params[0]]
+         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         # add update to accumulator
+         torch._foreach_add_(accumulator, vars.get_update())
+
+         # step with accumulated updates
+         if step % n == 0:
+             if mean:
+                 torch._foreach_div_(accumulator, n)
+
+             vars.update = [a.clone() for a in accumulator]
+             vars = self.children['modules'].step(vars)
+
+             # zero accumulator
+             torch._foreach_zero_(accumulator)
+
+         else:
+             # prevent update
+             if stop:
+                 vars.stop = True
+                 vars.skip_update = True
+
+         return vars
+
+
+ class Dropout(Transform):
+     """Applies dropout to the update; with `graft=True` the result is rescaled to preserve the original global norm."""
+     def __init__(self, p: float = 0.5, graft: bool = False, target: Target = 'update'):
+         defaults = dict(p=p, graft=graft)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         tensors = TensorList(tensors)
+         p = self.get_settings('p', params=params, cls=NumberList)
+         graft = self.settings[params[0]]['graft']
+
+         if graft:
+             target_norm = tensors.global_vector_norm()
+             tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+             return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+         return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
+ class WeightDropout(Module):
+     """Applies dropout directly to weights."""
+     def __init__(self, p: float = 0.5, graft: bool = True):
+         defaults = dict(p=p, graft=graft)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, vars):
+         closure = vars.closure
+         if closure is None: raise RuntimeError('WeightDropout requires closure')
+         params = TensorList(vars.params)
+         p = self.get_settings('p', params=params)
+         mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
+
+         @torch.no_grad
+         def dropout_closure(backward=True):
+             orig_params = params.clone()
+             params.mul_(mask)
+             if backward:
+                 with torch.enable_grad(): loss = closure()
+             else:
+                 loss = closure(False)
+             params.copy_(orig_params)
+             return loss
+
+         vars.closure = dropout_closure
+         return vars
+
+ class NoiseSign(Transform):
+     """Uses a random vector with the sign of the update."""
+     def __init__(self, distribution: Distributions = 'normal', alpha = 1):
+         defaults = dict(distribution=distribution, alpha=alpha)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         alpha = self.get_settings('alpha', params=params)
+         distribution = self.settings[params[0]]['distribution']
+         return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+
+
+ class NegateOnLossIncrease(Module):
+     """Negates the update (or zeroes it if `backtrack=False`) when applying it would increase the loss."""
+     def __init__(self, backtrack=True):
+         defaults = dict(backtrack=backtrack)
+         super().__init__(defaults=defaults)
+
+     @torch.no_grad
+     def step(self, vars):
+         closure = vars.closure
+         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+         backtrack = self.settings[vars.params[0]]['backtrack']
+
+         update = vars.get_update()
+         f_0 = vars.get_loss(backward=False)
+
+         torch._foreach_sub_(vars.params, update)
+         f_1 = closure(False)
+
+         if f_1 <= f_0:
+             if vars.is_last and vars.last_module_lrs is None:
+                 vars.stop = True
+                 vars.skip_update = True
+                 return vars
+
+             torch._foreach_add_(vars.params, update)
+             return vars
+
+         torch._foreach_add_(vars.params, update)
+         if backtrack:
+             torch._foreach_neg_(vars.update)
+         else:
+             torch._foreach_zero_(vars.update)
+         return vars
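
Every `Transform` in this file follows the same contract: `__init__` forwards a `defaults` dict together with `uses_grad` and `target` to the base class, and `transform(tensors, params, grads, vars)` returns the new update tensors, with per-parameter settings fetched through `get_settings`. A hedged sketch of a transform built on that contract (the `ScaleByConstant` name is illustrative and not part of this release):

import torch

from ...core import Target, Transform
from ...utils import NumberList


class ScaleByConstant(Transform):
    """Hypothetical example, not in the wheel: multiplies the update by a per-parameter factor."""
    def __init__(self, factor: float = 0.5, target: Target = 'update'):
        defaults = dict(factor=factor)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        # per-parameter scalars, fetched the same way LastAbsoluteRatio fetches eps
        factor = self.get_settings('factor', params=params, cls=NumberList)
        torch._foreach_mul_(tensors, factor)
        return tensors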
torchzero/modules/ops/multi.py
@@ -0,0 +1,137 @@
+ #pyright: reportIncompatibleMethodOverride=false
+ """"""
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable, Sequence
+ from operator import itemgetter
+ from typing import Any
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Vars, maybe_chain
+ from ...utils import TensorList, tensorlist
+
+
+ class MultiOperation(Module, ABC):
+     """Base class for operations that use operands. This is an abstract class; subclass it and override the `transform` method to use it."""
+     def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
+         super().__init__(defaults=defaults)
+
+         self.operands = {}
+         for k, v in operands.items():
+
+             if isinstance(v, (Module, Sequence)):
+                 self.set_child(k, v)
+                 self.operands[k] = self.children[k]
+             else:
+                 self.operands[k] = v
+
+         if not self.children:
+             raise ValueError('At least one operand must be a module')
+
+     @abstractmethod
+     def transform(self, vars: Vars, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+         """Applies the operation to the operands."""
+         raise NotImplementedError
+
+     @torch.no_grad
+     def step(self, vars: Vars) -> Vars:
+         # pass cloned update to all module operands
+         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
+
+         for k, v in self.operands.items():
+             if k in self.children:
+                 v: Module
+                 updated_vars = v.step(vars.clone(clone_update=True))
+                 processed_operands[k] = updated_vars.get_update()
+                 vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc. if this module calculated them
+
+         transformed = self.transform(vars, **processed_operands)
+         vars.update = transformed
+         return vars
+
+
+
+ class SubModules(MultiOperation):
+     """Subtracts `other` from `input`, with `other` scaled by `alpha`."""
+     def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
+         defaults = dict(alpha=alpha)
+         super().__init__(defaults, input=input, other=other)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         alpha = self.settings[vars.params[0]]['alpha']
+
+         if isinstance(input, (int, float)):
+             assert isinstance(other, list)
+             return input - TensorList(other).mul_(alpha)
+
+         if isinstance(other, (int, float)): torch._foreach_sub_(input, other * alpha)
+         else: torch._foreach_sub_(input, other, alpha=alpha)
+         return input
+
+ class DivModules(MultiOperation):
+     """Divides `input` by `other`."""
+     def __init__(self, input: Chainable | float, other: Chainable | float):
+         defaults = {}
+         super().__init__(defaults, input=input, other=other)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         if isinstance(input, (int, float)):
+             assert isinstance(other, list)
+             return input / TensorList(other)
+
+         torch._foreach_div_(input, other)
+         return input
+
+ class PowModules(MultiOperation):
+     """Raises `input` to the power of `exponent`."""
+     def __init__(self, input: Chainable | float, exponent: Chainable | float):
+         defaults = {}
+         super().__init__(defaults, input=input, exponent=exponent)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         if isinstance(input, (int, float)):
+             assert isinstance(exponent, list)
+             return input ** TensorList(exponent)
+
+         torch._foreach_pow_(input, exponent)
+         return input
+
+ class LerpModules(MultiOperation):
+     """Linearly interpolates `input` towards `end` with the given `weight`."""
+     def __init__(self, input: Chainable, end: Chainable, weight: float):
+         defaults = dict(weight=weight)
+         super().__init__(defaults, input=input, end=end)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
+         torch._foreach_lerp_(input, end, weight=self.settings[vars.params[0]]['weight'])
+         return input
+
+ class ClipModules(MultiOperation):
+     """Clamps `input` to the range [`min`, `max`]."""
+     def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
+         defaults = {}
+         super().__init__(defaults, input=input, min=min, max=max)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         return TensorList(input).clamp_(min=min, max=max)
+
+
+ class GraftModules(MultiOperation):
+     """Combines the direction of `direction` with the magnitude of `magnitude`."""
+     def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise: bool = True, ord: float = 2, eps: float = 1e-6, strength: float = 1):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
+         super().__init__(defaults, direction=direction, magnitude=magnitude)
+
+     @torch.no_grad
+     def transform(self, vars, magnitude: list[torch.Tensor], direction: list[torch.Tensor]):
+         tensorwise, ord, eps, strength = itemgetter('tensorwise', 'ord', 'eps', 'strength')(self.settings[vars.params[0]])
+         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
+
+
+ class Where(MultiOperation):
+     """Selects elements from `input` where `condition` is true and from `other` elsewhere."""
+     def __init__(self, condition: Chainable, input: Chainable | float, other: Chainable | float):
+         super().__init__({}, condition=condition, input=input, other=other)
+
+     @torch.no_grad
+     def transform(self, vars, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
+         return tensorlist.where(TensorList(condition).as_bool(), input, other)
+
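
`MultiOperation.step` runs each module operand on a clone of the current update and passes the results to `transform` as keyword arguments, so a new operation only needs an `__init__` that forwards its operands and a `transform` that combines them. A hedged sketch in the style of `SubModules`/`DivModules` above (the `MaximumModules` name is illustrative and not part of this release):

import torch

from ...core import Chainable


class MaximumModules(MultiOperation):
    """Hypothetical example, not in the wheel: element-wise maximum of two modules' outputs."""
    def __init__(self, input: Chainable, other: Chainable):
        super().__init__({}, input=input, other=other)

    @torch.no_grad
    def transform(self, vars, input: list[torch.Tensor], other: list[torch.Tensor]) -> list[torch.Tensor]:
        # both operands arrive as lists of cloned update tensors produced by the wrapped modules
        torch._foreach_maximum_(input, other)
        return input

Non-module operands are passed through to `transform` unchanged, which is how the scalar branches of `SubModules` and `DivModules` receive plain floats.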