torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -0,0 +1,194 @@
+from collections.abc import Iterable
+
+import torch
+
+from ...core import Chainable, Module, Var
+from ...utils import TensorList
+
+def _sequential_step(self: Module, var: Var, sequential: bool):
+    params = var.params
+    steps = self.settings[params[0]]['steps']
+
+    if sequential: modules = self.get_children_sequence() * steps
+    else: modules = [self.children['module']] * steps
+
+    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+    # store original params unless this is the last module and it can update params directly
+    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+
+    # first step - pass var as usual
+    var = modules[0].step(var)
+    new_var = var
+
+    # subsequent steps - update parameters and create a new var
+    if len(modules) > 1:
+        for m in modules[1:]:
+
+            # update params
+            if (not new_var.skip_update):
+                if new_var.last_module_lrs is not None:
+                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+                torch._foreach_sub_(params, new_var.get_update())
+
+            # create a new var since we are at a new point, so grad, update and loss will be None
+            new_var = Var(params=new_var.params, closure=new_var.closure,
+                          model=new_var.model, current_step=new_var.current_step + 1)
+
+            # step
+            new_var = m.step(new_var)
+
+    # final parameter update
+    if (not new_var.skip_update):
+        if new_var.last_module_lrs is not None:
+            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+        torch._foreach_sub_(params, new_var.get_update())
+
+    # if last module, update is applied so return new var
+    if params_before_steps is None:
+        new_var.stop = True
+        new_var.skip_update = True
+        return new_var
+
+    # otherwise use parameter difference as update
+    var.update = list(torch._foreach_sub(params_before_steps, params))
+    for p, bef in zip(params, params_before_steps):
+        p.set_(bef) # pyright:ignore[reportArgumentType]
+    return var
+
+class Multistep(Module):
+    """Performs :code:`steps` inner steps with :code:`module` on each step.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, module: Chainable, steps: int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_child('module', module)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=False)
+
+class Sequential(Module):
+    """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, modules: Iterable[Chainable], steps: int = 1):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_children_sequence(modules)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=True)
+
+
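A minimal usage sketch for the two modules above, assuming the usual ``import torchzero as tz`` alias used in the docstring examples elsewhere in this release; ``_sequential_step`` raises if no closure is available, so the optimizer has to be stepped with one.

```python
import torchzero as tz

# five Adam steps are performed per outer step; the outer update is the
# parameter difference before/after the inner loop, so a closure is required
opt = tz.Modular(
    model.parameters(),
    tz.m.Multistep(tz.m.Adam(), steps=5),
    tz.m.LR(1e-2),
)
```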
+class NegateOnLossIncrease(Module):
+    """Uses an extra forward pass to evaluate the loss at :code:`parameters + update`;
+    if that loss is larger than at :code:`parameters`,
+    the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise."""
+    def __init__(self, backtrack=False):
+        defaults = dict(backtrack=backtrack)
+        super().__init__(defaults=defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+        backtrack = self.defaults['backtrack']
+
+        update = var.get_update()
+        f_0 = var.get_loss(backward=False)
+
+        torch._foreach_sub_(var.params, update)
+        f_1 = closure(False)
+
+        if f_1 <= f_0:
+            if var.is_last and var.last_module_lrs is None:
+                var.stop = True
+                var.skip_update = True
+                return var
+
+            torch._foreach_add_(var.params, update)
+            return var
+
+        torch._foreach_add_(var.params, update)
+        if backtrack:
+            torch._foreach_neg_(var.update)
+        else:
+            torch._foreach_zero_(var.update)
+        return var
+
+
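An illustrative placement sketch for ``NegateOnLossIncrease``; the placement is an assumption rather than something stated in the release (the module needs a closure and checks ``var.is_last``, so it is shown as the final module in the chain).

```python
# if the proposed step increases the loss, it is negated (backtrack=True)
# or zeroed out (backtrack=False), at the cost of one extra forward pass
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-2),
    tz.m.NegateOnLossIncrease(backtrack=True),
)
```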
+class Online(Module):
+    """Allows certain modules to be used for mini-batch optimization.
+
+    Examples:
+
+    Online L-BFGS with Backtracking line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Online(tz.m.LBFGS()),
+        tz.m.Backtracking()
+    )
+    ```
+
+    Online L-BFGS trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
+    )
+    ```
+
+    """
+    def __init__(self, *modules: Module,):
+        super().__init__()
+
+        self.set_child('module', modules)
+
+    @torch.no_grad
+    def update(self, var):
+        closure = var.closure
+        if closure is None: raise ValueError("Closure must be passed for Online")
+
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
+
+        params = TensorList(var.params)
+        p_cur = params.clone()
+        p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+
+        module = self.children['module']
+        var_c = var.clone(clone_update=False)
+
+        # on 1st step just step and store previous params
+        if step == 1:
+            p_prev.copy_(params)
+
+            module.update(var_c)
+            var.update_attrs_from_clone_(var_c)
+            return
+
+        # restore previous params and update
+        var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+        params.set_(p_prev)
+        module.reset_for_online()
+        module.update(var_prev)
+
+        # restore current params and update
+        params.set_(p_cur)
+        p_prev.copy_(params)
+        module.update(var_c)
+        var.update_attrs_from_clone_(var_c)
+
+    @torch.no_grad
+    def apply(self, var):
+        module = self.children['module']
+        return module.apply(var.clone(clone_update=False))
+
+    def get_H(self, var):
+        return self.children['module'].get_H(var)
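The closures wrapped elsewhere in this release are invoked as ``closure()`` for a backward pass and ``closure(False)`` for a forward-only evaluation. A rough training-loop sketch for the ``Online`` examples above, assuming ``tz.Modular`` exposes the usual ``zero_grad``/``step(closure)`` optimizer interface (``loader`` and ``loss_fn`` are placeholders):

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.Online(tz.m.LBFGS()),
    tz.m.Backtracking(),
)

for inputs, targets in loader:
    def closure(backward=True):
        loss = loss_fn(model(inputs), targets)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss
    opt.step(closure)
```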
@@ -0,0 +1,167 @@
+import torch
+
+from ...core import Chainable, Module, Target, Transform
+from ...core.reformulation import Reformulation
+from ...utils import Distributions, NumberList, TensorList
+
+
+class Dropout(Transform):
+    """Applies dropout to the update.
+
+    For each weight the update to that weight has :code:`p` probability to be set to 0.
+    This can be used to implement gradient dropout or update dropout depending on placement.
+
+    Args:
+        p (float, optional): probability that update for a weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
+        target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
+
+
+    Examples:
+        Gradient dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Dropout(0.5),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Update dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Dropout(0.5),
+                tz.m.LR(1e-3)
+            )
+
+    """
+    def __init__(self, p: float = 0.5, graft: bool = False, target: Target = 'update'):
+        defaults = dict(p=p, graft=graft)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        p = NumberList(s['p'] for s in settings)
+        graft = settings[0]['graft']
+
+        if graft:
+            target_norm = tensors.global_vector_norm()
+            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
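A standalone illustration of the masking identity used in ``apply_tensors`` above, assuming ``rademacher_like(q)`` draws ±1 entries with probability ``q`` of +1 (which is what passing ``1-p`` implies):

```python
import torch

p = 0.3                                            # dropout probability
t = torch.randn(5)                                 # stand-in for one update tensor
r = (torch.rand_like(t) < 1 - p).float() * 2 - 1   # ±1 with P(+1) = 1 - p
mask = (r + 1) / 2                                 # +1 -> 1 (keep), -1 -> 0 (drop)
print(t * mask)                                    # each entry survives with probability 1 - p
```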
+def _bernoulli_like(tensor, p = 0.5, generator = None):
+    """p is the probability of a 1, other values will be 0."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator)
+
+class WeightDropout(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
+
+    Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in the corresponding parameter group.
+
+    Args:
+        p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to True.
+    """
+    def __init__(self, p: float = 0.5, graft: bool = True):
+        defaults = dict(p=p, graft=graft, use_dropout=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(var.params)
+        p = NumberList(self.settings[p]['p'] for p in params)
+
+        # create masks
+        mask = []
+        for p in params:
+            prob = self.settings[p]['p']
+            use_dropout = self.settings[p]['use_dropout']
+            if use_dropout: mask.append(_bernoulli_like(p, prob))
+            else: mask.append(torch.ones_like(p))
+
+        @torch.no_grad
+        def dropout_closure(backward=True):
+            orig_params = params.clone()
+            params.mul_(mask)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.copy_(orig_params)
+            return loss
+
+        var.closure = dropout_closure
+        return var
+
+
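A usage sketch for ``WeightDropout`` showing the per-group ``use_dropout`` switch documented above; the parameter-group syntax is assumed to follow ``torch.optim``, and ``model.body``/``model.head`` are illustrative names:

```python
opt = tz.Modular(
    [
        {"params": model.body.parameters()},                        # evaluated with weights dropped at p=0.2
        {"params": model.head.parameters(), "use_dropout": False},  # dropout disabled for this group
    ],
    tz.m.WeightDropout(p=0.2),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)
```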
+class PerturbWeights(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
+
+    Can be disabled for a parameter by setting :code:`perturb=False` in the corresponding parameter group.
+
+    Args:
+        alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
+        relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
+        distribution (Distributions, optional):
+            distribution of the random perturbation. Defaults to 'normal'.
+    """
+    def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
+        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('PerturbWeights requires closure')
+        params = TensorList(var.params)
+
+        # create perturbations
+        perts = []
+        for p in params:
+            settings = self.settings[p]
+            if not settings['perturb']:
+                perts.append(torch.zeros_like(p))
+                continue
+
+            alpha = settings['alpha']
+            if settings['relative']:
+                alpha *= p.abs().mean()
+
+            distribution = self.settings[p]['distribution'].lower()
+            if distribution in ('normal', 'gaussian'):
+                perts.append(torch.randn_like(p).mul_(alpha))
+            elif distribution == 'uniform':
+                perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
+            elif distribution == 'sphere':
+                r = torch.randn_like(p)
+                perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
+            else:
+                raise ValueError(distribution)
+
+        @torch.no_grad
+        def perturbed_closure(backward=True):
+            params.add_(perts)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.sub_(perts)
+            return loss
+
+        var.closure = perturbed_closure
+        return var
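A usage sketch for ``PerturbWeights``; the placement mirrors the closure-modifying ``WeightDropout`` above, i.e. before the modules that consume the gradient:

```python
# the loss and gradients are evaluated at weights plus uniform noise whose
# magnitude is 1% of each parameter's mean absolute value (relative=True)
opt = tz.Modular(
    model.parameters(),
    tz.m.PerturbWeights(alpha=0.01, relative=True, distribution='uniform'),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)
```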
@@ -0,0 +1,123 @@
+import warnings
+from collections.abc import Callable, Sequence, Iterable
+from typing import cast
+
+import torch
+
+from ...core import Chainable, Module, Var
+
+
+def _split(
+    module: Module,
+    idxs,
+    params,
+    var: Var,
+):
+    split_params = [p for i,p in enumerate(params) if i in idxs]
+
+    split_grad = None
+    if var.grad is not None:
+        split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
+
+    split_update = None
+    if var.update is not None:
+        split_update = [u for i,u in enumerate(var.update) if i in idxs]
+
+    split_var = var.clone(clone_update=False, parent=var)
+    split_var.params = split_params
+    split_var.grad = split_grad
+    split_var.update = split_update
+
+    split_var = module.step(split_var)
+
+    # those should be set due to var being parent
+    if split_var.grad is not None:
+        assert var.grad is not None
+
+    if split_var.loss is not None:
+        assert var.loss is not None
+
+    if split_var.update is not None:
+
+        # make sure var.update is set, it will be filled with tensors from the ``true`` and ``false`` children
+        if var.update is None:
+            if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
+            else: var.update = [g.clone() for g in var.grad]
+
+        # set all tensors from this split
+        for idx, u in zip(idxs, split_var.update):
+            var.update[idx] = u
+
+    return var
+
+_SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.Tensor] | torch.nn.Module | Iterable[torch.nn.Module]
+Filter = _SingleFilter | Iterable[_SingleFilter]
+
+def _make_filter(filter: Filter):
+    if callable(filter): return filter
+    if isinstance(filter, torch.Tensor):
+        return lambda x: x is filter
+    if isinstance(filter, torch.nn.Module):
+        return _make_filter(filter.parameters())
+
+    # iterable
+    filters = [_make_filter(f) for f in filter]
+    return lambda x: any(f(x) for f in filters)
+
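Every accepted filter form ends up as a predicate over tensors; a quick sketch of what ``_make_filter`` produces (``model`` is an illustrative ``torch.nn.Module``):

```python
f_callable = _make_filter(lambda t: t.ndim >= 2)       # callables are used as-is
f_tensors  = _make_filter(model.head.parameters())     # identity match against those tensors
f_module   = _make_filter(model.head)                  # same as passing its .parameters()
f_any      = _make_filter([model.head, lambda t: t.ndim >= 2])  # True if any sub-filter matches
```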
+class Split(Module):
+    """Apply ``true`` modules to all parameters filtered by ``filter``, apply ``false`` modules to all other parameters.
+
+    Args:
+        filter (Filter):
+            a filter that selects tensors to be optimized by ``true``.
+            - tensor or iterable of tensors (e.g. ``encoder.parameters()``).
+            - function that takes in tensor and outputs a bool (e.g. ``lambda x: x.ndim >= 2``).
+            - a sequence of above (acts as "or", so returns true if any of them is true).
+
+        true (Chainable | None): modules that are applied to tensors where ``filter`` is ``True``.
+        false (Chainable | None): modules that are applied to tensors where ``filter`` is ``False``.
+
+    ### Examples:
+
+    Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NAG(0.95),
+        tz.m.Split(
+            lambda p: p.ndim >= 2,
+            true = tz.m.Orthogonalize(),
+            false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
+        ),
+        tz.m.LR(1e-2),
+    )
+    ```
+    """
+    def __init__(self, filter: Filter, true: Chainable | None, false: Chainable | None):
+        defaults = dict(filter=filter)
+        super().__init__(defaults)
+
+        if true is not None: self.set_child('true', true)
+        if false is not None: self.set_child('false', false)
+
+    def step(self, var):
+
+        params = var.params
+        filter = _make_filter(self.settings[params[0]]['filter'])
+
+        true_idxs = []
+        false_idxs = []
+        for i,p in enumerate(params):
+            if filter(p): true_idxs.append(i)
+            else: false_idxs.append(i)
+
+        if 'true' in self.children and len(true_idxs) > 0:
+            true = self.children['true']
+            var = _split(true, idxs=true_idxs, params=params, var=var)
+
+        if 'false' in self.children and len(false_idxs) > 0:
+            false = self.children['false']
+            var = _split(false, idxs=false_idxs, params=params, var=var)
+
+        return var
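Besides the ``lambda p: p.ndim >= 2`` filter in the docstring above, tensors can be selected by identity; a sketch assuming the model has an ``encoder`` submodule:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.Split(
        model.encoder.parameters(),             # these go through `true`
        true=tz.m.Adam(),
        false=[tz.m.RMSprop(), tz.m.Mul(0.1)],  # everything else
    ),
    tz.m.LR(1e-3),
)
```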
@@ -7,7 +7,28 @@ from ...core import Chainable, Module
 
 
 class Alternate(Module):
-    """alternate between stepping with `modules`"""
+    """Alternates between stepping with :code:`modules`.
+
+    That is, the first step is performed with the first module, the second step with the second module, and so on.
+
+    Args:
+        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
+
+    Examples:
+        Alternate between Adam, SignSGD and RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Alternate(
+                    tz.m.Adam(),
+                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
+                    tz.m.RMSprop(),
+                ),
+                tz.m.LR(1e-3),
+            )
+    """
     LOOP = True
     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
         if isinstance(steps, Iterable):
@@ -32,7 +53,7 @@ class Alternate(Module):
         var = module.step(var.clone(clone_update=False))
 
         # number of steps until next module
-        steps = self.settings[var.params[0]]['steps']
+        steps = self.defaults['steps']
         if isinstance(steps, int): steps = [steps]*len(self.children)
 
         if 'steps_to_next' not in self.global_state:
@@ -54,14 +75,34 @@ class Alternate(Module):
         return var
 
 class Switch(Alternate):
-    """switch to next module after some steps"""
+    """After :code:`steps` steps switches to the next module.
+
+    Args:
+        steps (int | Iterable[int]): Number of steps to perform with each module.
+
+    Examples:
+        Start with Adam, switch to L-BFGS after the 1000th step and to Truncated Newton on the 2000th step.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Switch(
+                    [tz.m.Adam(), tz.m.LR(1e-3)],
+                    [tz.m.LBFGS(), tz.m.Backtracking()],
+                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+                    steps = (1000, 2000)
+                )
+            )
+    """
+
     LOOP = False
     def __init__(self, *modules: Chainable, steps: int | Iterable[int]):
 
         if isinstance(steps, Iterable):
             steps = list(steps)
             if len(steps) != len(modules) - 1:
-                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")
 
             steps.append(1)
 
@@ -6,9 +6,5 @@ from .cautious import (
     ScaleModulesByCosineSimilarity,
     UpdateGradientSignConsistency,
 )
-from .ema import EMA, Debias, Debias2, EMASquared, SqrtEMASquared, CenteredEMASquared, CenteredSqrtEMASquared
-from .experimental import CoordinateMomentum
-# from .matrix_momentum import MatrixMomentum
 
-from .momentum import NAG, HeavyBall
-from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+from .momentum import NAG, HeavyBall, EMA
@@ -10,7 +10,7 @@ from ...utils import tolist
 
 
 class Averaging(TensorwiseTransform):
-    """Average of past :code:`history_size` updates.
+    """Average of past ``history_size`` updates.
 
     Args:
         history_size (int): Number of past updates to average
@@ -21,8 +21,8 @@ class Averaging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        history_size = settings['history_size']
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
             state['average'] = torch.zeros_like(tensor)
@@ -35,7 +35,7 @@
         return average / len(history)
 
 class WeightedAveraging(TensorwiseTransform):
-    """Weighted average of past :code:`len(weights)` updates.
+    """Weighted average of past ``len(weights)`` updates.
 
     Args:
         weights (Sequence[float]): a sequence of weights from oldest to newest.
@@ -46,8 +46,8 @@ class WeightedAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        weights = settings['weights']
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        weights = setting['weights']
         if 'history' not in state:
             state['history'] = deque(maxlen=len(weights))
 
@@ -69,7 +69,7 @@ class WeightedAveraging(TensorwiseTransform):
 
 
 class MedianAveraging(TensorwiseTransform):
-    """Median of past :code:`history_size` updates.
+    """Median of past ``history_size`` updates.
 
     Args:
         history_size (int): Number of past updates to average
@@ -80,8 +80,8 @@ class MedianAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        history_size = settings['history_size']
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
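A brief usage sketch for the averaging transforms above; placing the transform after ``Adam`` averages its output, mirroring the "update dropout" pattern shown earlier in this diff:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightedAveraging([0.1, 0.2, 0.3, 0.4]),  # weights from oldest to newest
    tz.m.LR(1e-3),
)
```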