torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
+ from collections import deque
+ from collections.abc import Iterable
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+ from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+ class Previous(TensorwiseTransform):
+     """Maintains an update from n steps back; for example, if n=1, returns the previous update."""
+     def __init__(self, n=1, target: Target = 'update'):
+         defaults = dict(n=n)
+         super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         n = setting['n']
+
+         if 'history' not in state:
+             state['history'] = deque(maxlen=n+1)
+
+         state['history'].append(tensor)
+
+         return state['history'][0]
+
+
+ class LastDifference(Transform):
+     """Outputs the difference between the past two updates."""
+     def __init__(self,target: Target = 'update'):
+         super().__init__({}, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
+         difference = torch._foreach_sub(tensors, prev_tensors)
+         for p, c in zip(prev_tensors, tensors): p.set_(c)
+         return difference
+
+ class LastGradDifference(Module):
+     """Outputs the difference between the past two gradients."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, var):
+         grad = var.get_grad()
+         prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
+         difference = torch._foreach_sub(grad, prev_grad)
+         for p, c in zip(prev_grad, grad): p.copy_(c)
+         var.update = list(difference)
+         return var
+
+ class LastParamDifference(Module):
+     """Outputs the difference between the past two sets of parameters, which is the effective previous update."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         prev_params = self.get_state(var.params, 'prev_params') # initialized to 0
+         difference = torch._foreach_sub(params, prev_params)
+         for p, c in zip(prev_params, params): p.copy_(c)
+         var.update = list(difference)
+         return var
+
+
+
+ class LastProduct(Transform):
+     """Outputs the product of the past two updates."""
+     def __init__(self,target: Target = 'update'):
+         super().__init__({}, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+         prod = torch._foreach_mul(tensors, prev)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return prod
+
+ class LastRatio(Transform):
+     """Outputs the ratio between the past two updates; the numerator is determined by the :code:`numerator` argument."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+         defaults = dict(numerator=numerator)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+         numerator = settings[0]['numerator']
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+         else: ratio = torch._foreach_div(prev, tensors)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return ratio
+
+ class LastAbsoluteRatio(Transform):
+     """Outputs the ratio between the absolute values of the past two updates; the numerator is determined by the :code:`numerator` argument."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+         defaults = dict(numerator=numerator, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+         numerator = settings[0]['numerator']
+         eps = NumberList(s['eps'] for s in settings)
+
+         torch._foreach_abs_(tensors)
+         torch._foreach_clamp_min_(prev, eps)
+
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+         else: ratio = torch._foreach_div(prev, tensors)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return ratio
+
+ class GradSign(Transform):
+     """Copies gradient sign to update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         return [t.copysign_(g) for t,g in zip(tensors, grads)]
+
+ class UpdateSign(Transform):
+     """Outputs the gradient with sign copied from the update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
+
+ class GraftToGrad(Transform):
+     """Grafts update to the gradient, that is, the update is rescaled to have the same norm as the gradient."""
+     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class GraftGradToUpdate(Transform):
+     """Outputs the gradient grafted to the update, that is, the gradient rescaled to have the same norm as the update."""
+     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+ class GraftToParams(Transform):
+     """Grafts update to the parameters, that is, the update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
+     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class Relative(Transform):
+     """Multiplies the update by absolute parameter values to make it relative to their magnitude; :code:`min_value` is the minimum allowed value to avoid getting stuck at 0."""
+     def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+         defaults = dict(min_value=min_value)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
+         torch._foreach_mul_(tensors, mul)
+         return tensors
+
+ class FillLoss(Module):
+     """Outputs tensors filled with the loss value times :code:`alpha`."""
+     def __init__(self, alpha: float = 1, backward: bool = True):
+         defaults = dict(alpha=alpha, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha = self.get_settings(var.params, 'alpha')
+         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+         var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+         return var
+
+ class MulByLoss(Module):
+     """Multiplies the update by the loss times :code:`alpha`."""
+     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+         torch._foreach_mul_(var.update, mul)
+         return var
+
+ class DivByLoss(Module):
+     """Divides the update by the loss times :code:`alpha`."""
+     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+         torch._foreach_div_(var.update, mul)
+         return var
+
+
+ class NoiseSign(Transform):
+     """Outputs random tensors with sign copied from the update."""
+     def __init__(self, distribution:Distributions = 'normal', alpha = 1):
+         defaults = dict(distribution=distribution, alpha=alpha)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         alpha = [s['alpha'] for s in settings]
+         distribution = self.settings[params[0]]['distribution']
+         return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+
+ class HpuEstimate(Transform):
+     """Returns ``y/||s||``, where ``y`` is the difference between the current and previous update (gradient) and ``s`` is the difference between the current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the previous update."""
+     def __init__(self):
+         defaults = dict()
+         super().__init__(defaults, uses_grad=False)
+
+     def reset_for_online(self):
+         super().reset_for_online()
+         self.clear_state_keys('prev_params', 'prev_update')
+
+     @torch.no_grad
+     def update_tensors(self, tensors, params, grads, loss, states, settings):
+         prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
+         s = torch._foreach_sub(params, prev_params)
+         y = torch._foreach_sub(tensors, prev_update)
+         for p, c in zip(prev_params, params): p.copy_(c)
+         for p, c in zip(prev_update, tensors): p.copy_(c)
+         torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
+         self.store(params, ['s', 'y'], [s, y])
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         return [self.state[p]['y'] for p in params]
+
+ class RandomHvp(Module):
+     """Returns a Hessian-vector product with a random vector."""
+
+     def __init__(
+         self,
+         n_samples: int = 1,
+         distribution: Distributions = "normal",
+         update_freq: int = 1,
+         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         h=1e-3,
+     ):
+         defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+         distribution = settings['distribution']
+         hvp_method = settings['hvp_method']
+         h = settings['h']
+         update_freq = settings['update_freq']
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         D = None
+         if step % update_freq == 0:
+
+             rgrad = None
+             for i in range(n_samples):
+                 u = params.sample_like(distribution=distribution)
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                       h=h, normalize=True, retain_grad=i < n_samples-1)
+
+                 if D is None: D = Hvp
+                 else: torch._foreach_add_(D, Hvp)
+
+             if n_samples > 1: torch._foreach_div_(D, n_samples)
+             if update_freq != 1:
+                 assert D is not None
+                 D_buf = self.get_state(params, "D", cls=TensorList)
+                 D_buf.set_(D)
+
+         if D is None:
+             D = self.get_state(params, "D", cls=TensorList)
+
+         var.update = list(D)
+         return var
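As a rough usage sketch of the new transforms above (not part of the diff): following the tz.Modular chaining pattern used in the Dropout and Split docstrings later in this diff, and assuming these classes are exported under tz.m alongside tz.m.Adam and tz.m.LR, grafting the Adam direction back to the raw gradient's norm would look like:

    import torchzero as tz

    # hypothetical chain: compute the Adam direction, rescale it to the
    # gradient's norm with GraftToGrad, then apply a learning rate
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.GraftToGrad(),  # update rescaled to have the same norm as the gradient
        tz.m.LR(1e-2),
    )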
@@ -0,0 +1,158 @@
+ from collections.abc import Iterable
+
+ import torch
+
+ from ...core import Chainable, Module, Var
+ from ...utils import TensorList
+
+ def _sequential_step(self: Module, var: Var, sequential: bool):
+     params = var.params
+     steps = self.settings[params[0]]['steps']
+
+     if sequential: modules = self.get_children_sequence() * steps
+     else: modules = [self.children['module']] * steps
+
+     if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+     # store original params unless this is last module and can update params directly
+     params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+
+     # first step - pass var as usual
+     var = modules[0].step(var)
+     new_var = var
+
+     # subsequent steps - update parameters and create new var
+     if len(modules) > 1:
+         for m in modules[1:]:
+
+             # update params
+             if (not new_var.skip_update):
+                 if new_var.last_module_lrs is not None:
+                     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+                 torch._foreach_sub_(params, new_var.get_update())
+
+             # create new var since we are at a new point, that means grad, update and loss will be None
+             new_var = Var(params=new_var.params, closure=new_var.closure,
+                           model=new_var.model, current_step=new_var.current_step + 1)
+
+             # step
+             new_var = m.step(new_var)
+
+     # final parameter update
+     if (not new_var.skip_update):
+         if new_var.last_module_lrs is not None:
+             torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+         torch._foreach_sub_(params, new_var.get_update())
+
+     # if last module, update is applied so return new var
+     if params_before_steps is None:
+         new_var.stop = True
+         new_var.skip_update = True
+         return new_var
+
+     # otherwise use parameter difference as update
+     var.update = list(torch._foreach_sub(params_before_steps, params))
+     for p, bef in zip(params, params_before_steps):
+         p.set_(bef) # pyright:ignore[reportArgumentType]
+     return var
+
+ class Multistep(Module):
+     """Performs :code:`steps` inner steps with :code:`module` on each step.
+
+     The update is taken to be the parameter difference between parameters before and after the inner loop."""
+     def __init__(self, module: Chainable, steps: int):
+         defaults = dict(steps=steps)
+         super().__init__(defaults)
+         self.set_child('module', module)
+
+     @torch.no_grad
+     def step(self, var):
+         return _sequential_step(self, var, sequential=False)
+
+ class Sequential(Module):
+     """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+
+     The update is taken to be the parameter difference between parameters before and after the inner loop."""
+     def __init__(self, modules: Iterable[Chainable], steps: int=1):
+         defaults = dict(steps=steps)
+         super().__init__(defaults)
+         self.set_children_sequence(modules)
+
+     @torch.no_grad
+     def step(self, var):
+         return _sequential_step(self, var, sequential=True)
+
+
+ class NegateOnLossIncrease(Module):
+     """Uses an extra forward pass to evaluate the loss at :code:`parameters+update`;
+     if that loss is larger than at :code:`parameters`,
+     the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise."""
+     def __init__(self, backtrack=False):
+         defaults = dict(backtrack=backtrack)
+         super().__init__(defaults=defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+         backtrack = self.settings[var.params[0]]['backtrack']
+
+         update = var.get_update()
+         f_0 = var.get_loss(backward=False)
+
+         torch._foreach_sub_(var.params, update)
+         f_1 = closure(False)
+
+         if f_1 <= f_0:
+             if var.is_last and var.last_module_lrs is None:
+                 var.stop = True
+                 var.skip_update = True
+                 return var
+
+             torch._foreach_add_(var.params, update)
+             return var
+
+         torch._foreach_add_(var.params, update)
+         if backtrack:
+             torch._foreach_neg_(var.update)
+         else:
+             torch._foreach_zero_(var.update)
+         return var
+
+
+ class Online(Module):
+     """Allows certain modules to be used for mini-batch optimization by re-evaluating their state at the previous parameters on the current mini-batch before each step."""
+     def __init__(self, module: Chainable):
+         super().__init__()
+
+         self.set_child('module', module)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise ValueError("Closure must be passed for Online")
+         step = self.global_state.get('step', 0) + 1
+         self.global_state['step'] = step
+         params = TensorList(var.params)
+         p_cur = params.clone()
+         p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+         module = self.children['module']
+
+         if step == 1:
+             var = module.step(var.clone(clone_update=False))
+
+             p_prev.copy_(params)
+             return var
+
+         # restore previous params
+         var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+         params.set_(p_prev)
+         module.reset_for_online()
+         module.update(var_prev)
+
+         # restore current params
+         params.set_(p_cur)
+         p_prev.copy_(params)
+         return module.step(var.clone(clone_update=False))
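A minimal sketch of how the Online wrapper might be used (not part of the diff; tz.m.LBFGS is an assumed export name based on torchzero/modules/quasi_newton/lbfgs.py in the file list above): wrapping a stateful module in Online refreshes its stored state on the current mini-batch before each step is taken:

    import torchzero as tz

    # hypothetical chain: the wrapped L-BFGS history is re-evaluated on the
    # current mini-batch by Online before the step is applied
    opt = tz.Modular(
        model.parameters(),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.LR(1e-1),
    )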
@@ -0,0 +1,171 @@
+ from collections import deque
+ from collections.abc import Iterable
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+ from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+ class Dropout(Transform):
+     """Applies dropout to the update.
+
+     For each weight the update to that weight has :code:`p` probability to be set to 0.
+     This can be used to implement gradient dropout or update dropout depending on placement.
+
+     Args:
+         p (float, optional): probability that the update for a weight is replaced with 0. Defaults to 0.5.
+         graft (bool, optional):
+             if True, the update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
+         target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
+
+
+     Examples:
+         Gradient dropout.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Dropout(0.5),
+                 tz.m.Adam(),
+                 tz.m.LR(1e-3)
+             )
+
+         Update dropout.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Adam(),
+                 tz.m.Dropout(0.5),
+                 tz.m.LR(1e-3)
+             )
+
+     """
+     def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
+         defaults = dict(p=p, graft=graft)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = TensorList(tensors)
+         p = NumberList(s['p'] for s in settings)
+         graft = settings[0]['graft']
+
+         if graft:
+             target_norm = tensors.global_vector_norm()
+             tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+             return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+         return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
+ def _bernoulli_like(tensor, p = 0.5, generator = None):
+     """p is the probability of a 1; other values will be 0."""
+     return torch.bernoulli(torch.full_like(tensor, p), generator = generator)
+
+ class WeightDropout(Module):
+     """
+     Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
+
+     Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in the corresponding parameter group.
+
+     Args:
+         p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
+         graft (bool, optional):
+             if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to True.
+     """
+     def __init__(self, p: float = 0.5, graft: bool = True):
+         defaults = dict(p=p, graft=graft, use_dropout=True)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError('WeightDropout requires closure')
+         params = TensorList(var.params)
+
+         # create masks
+         mask = []
+         for p in params:
+             prob = self.settings[p]['p']
+             use_dropout = self.settings[p]['use_dropout']
+             if use_dropout: mask.append(_bernoulli_like(p, prob))
+             else: mask.append(torch.ones_like(p))
+
+         @torch.no_grad
+         def dropout_closure(backward=True):
+             orig_params = params.clone()
+             params.mul_(mask)
+             if backward:
+                 with torch.enable_grad(): loss = closure()
+             else:
+                 loss = closure(False)
+             params.copy_(orig_params)
+             return loss
+
+         var.closure = dropout_closure
+         return var
+
+
+ class PerturbWeights(Module):
+     """
+     Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
+
+     Can be disabled for a parameter by setting :code:`perturb=False` in the corresponding parameter group.
+
+     Args:
+         alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
+         relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
+         distribution (Distributions, optional):
+             distribution of the perturbation, one of 'normal'/'gaussian', 'uniform' or 'sphere'. Defaults to 'normal'.
+     """
+     def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
+         defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError('PerturbWeights requires closure')
+         params = TensorList(var.params)
+
+         # create perturbations
+         perts = []
+         for p in params:
+             settings = self.settings[p]
+             if not settings['perturb']:
+                 perts.append(torch.zeros_like(p))
+                 continue
+
+             alpha = settings['alpha']
+             if settings['relative']:
+                 alpha *= p.abs().mean()
+
+             distribution = self.settings[p]['distribution'].lower()
+             if distribution in ('normal', 'gaussian'):
+                 perts.append(torch.randn_like(p).mul_(alpha))
+             elif distribution == 'uniform':
+                 perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
+             elif distribution == 'sphere':
+                 r = torch.randn_like(p)
+                 perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
+             else:
+                 raise ValueError(distribution)
+
+         @torch.no_grad
+         def perturbed_closure(backward=True):
+             params.add_(perts)
+             if backward:
+                 with torch.enable_grad(): loss = closure()
+             else:
+                 loss = closure(False)
+             params.sub_(perts)
+             return loss
+
+         var.closure = perturbed_closure
+         return var
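A minimal usage sketch for the closure-wrapping modules above (not part of the diff; it reuses the tz.Modular pattern from the Dropout docstring): since PerturbWeights only rewrites var.closure, it is placed before the modules that consume the gradient:

    import torchzero as tz

    # hypothetical chain: every closure evaluation happens at randomly
    # perturbed weights; the resulting gradient is then scaled by LR
    opt = tz.Modular(
        model.parameters(),
        tz.m.PerturbWeights(alpha=0.01, relative=True),
        tz.m.LR(1e-2),
    )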
@@ -45,7 +45,35 @@ def _split(
     return var
 
 class Split(Module):
-    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
+    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters.
+
+    Args:
+        filter (Callable[[torch.Tensor], bool]): a function that takes in a parameter tensor and returns a boolean value.
+        true (Chainable | None): modules that are applied to tensors where :code:`filter` returned True.
+        false (Chainable | None): modules that are applied to tensors where :code:`filter` returned False.
+
+    Examples:
+        standard Muon with Adam fallback
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.head.parameters(),
+                tz.m.Split(
+                    # apply muon only to 2D+ parameters
+                    filter = lambda t: t.ndim >= 2,
+                    true = [
+                        tz.m.HeavyBall(),
+                        tz.m.Orthogonalize(),
+                        tz.m.LR(1e-2),
+                    ],
+                    false = tz.m.Adam()
+                ),
+                tz.m.LR(1e-2)
+            )
+
+
+    """
     def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
         defaults = dict(filter=filter)
         super().__init__(defaults)