torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,7 +1,10 @@
+import math
+
+from typing import Literal
 import torch
 
-from ...core import Module
-from ...utils import TensorList, NumberList
+from ...core import Modular, Module, Var, Chainable
+from ...utils import NumberList, TensorList
 
 
 class EscapeAnnealing(Module):
@@ -42,19 +45,18 @@ class EscapeAnnealing(Module):
         if n_bad >= n_tol:
             for i in range(1, max_iter+1):
                 alpha = max_region * (i / max_iter)
-                pert = params.sample_like(distribution='sphere').mul_(alpha)
+                pert = params.sphere_like(radius=alpha)
 
                 params.add_(pert)
                 f_star = closure(False)
 
-                if f_star < f_0-1e-10:
+                if math.isfinite(f_star) and f_star < f_0-1e-12:
                     var.update = None
                     var.stop = True
                     var.skip_update = True
                     return var
 
-                else:
-                    params.sub_(pert)
+                params.sub_(pert)
 
         self.global_state['n_bad'] = 0
-        return var
+        return var
@@ -3,46 +3,112 @@ import torch
 from ...core import Chainable, Module
 
 
+# class GradientAccumulation(Module):
+#     """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
+
+#     Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
+#     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+#     .. note::
+#         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+#     Args:
+#         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+#         n (int): number of gradients to accumulate.
+#         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+#         stop (bool, optional):
+#             this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
+
+#     Examples:
+#         Adam with gradients accumulated for 16 batches.
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.GradientAccumulation(
+#                     [tz.m.Adam(), tz.m.LR(1e-2)],
+#                     n=16
+#                 )
+#             )
+
+#     """
+#     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+#         defaults = dict(n=n, mean=mean, stop=stop)
+#         super().__init__(defaults)
+#         self.set_child('modules', modules)
+
+
+#     @torch.no_grad
+#     def step(self, var):
+#         accumulator = self.get_state(var.params, 'accumulator')
+#         settings = self.defaults
+#         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+#         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+#         # add update to accumulator
+#         torch._foreach_add_(accumulator, var.get_update())
+
+#         # step with accumulated updates
+#         if step % n == 0:
+#             if mean:
+#                 torch._foreach_div_(accumulator, n)
+
+#             var.update = [a.clone() for a in accumulator]
+#             var = self.children['modules'].step(var)
+
+#             # zero accumulator
+#             torch._foreach_zero_(accumulator)
+
+#         else:
+#             # prevent update
+#             if stop:
+#                 var.update = None
+#                 var.stop=True
+#                 var.skip_update=True
+
+#         return var
+
+
+
+
 class GradientAccumulation(Module):
-    """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
+    """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
 
-    Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
+    Accumulating gradients for ``n`` steps is equivalent to increasing batch size by ``n``. Increasing the batch size
     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
 
-    .. note::
+    Note:
         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
 
     Args:
-        modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
        n (int): number of gradients to accumulate.
        mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
        stop (bool, optional):
-            this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
-
-    Examples:
-        Adam with gradients accumulated for 16 batches.
+            this module prevents next modules from stepping unless ``n`` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
 
-    .. code-block:: python
+    ## Examples:
 
-        opt = tz.Modular(
-            model.parameters(),
-            tz.m.GradientAccumulation(
-                modules=[tz.m.Adam(), tz.m.LR(1e-2)],
-                n=16
-            )
-        )
+    Adam with gradients accumulated for 16 batches.
 
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.GradientAccumulation(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2),
+    )
+    ```
    """
-    def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+    def __init__(self, n: int, mean=True, stop=True):
        defaults = dict(n=n, mean=mean, stop=stop)
        super().__init__(defaults)
-        self.set_child('modules', modules)
 
 
    @torch.no_grad
    def step(self, var):
        accumulator = self.get_state(var.params, 'accumulator')
-        settings = self.settings[var.params[0]]
+        settings = self.defaults
        n = settings['n']; mean = settings['mean']; stop = settings['stop']
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
@@ -54,15 +120,15 @@ class GradientAccumulation(Module):
            if mean:
                torch._foreach_div_(accumulator, n)
 
-            var.update = [a.clone() for a in accumulator]
-            var = self.children['modules'].step(var)
+            var.update = accumulator
 
            # zero accumulator
-            torch._foreach_zero_(accumulator)
+            self.clear_state_keys('accumulator')
 
        else:
            # prevent update
            if stop:
+                var.update = None
                var.stop=True
                var.skip_update=True
 
@@ -0,0 +1,59 @@
+from collections.abc import Callable
+from abc import ABC, abstractmethod
+import torch
+from ...core import Module
+from ...core import Chainable
+
+class HomotopyBase(Module):
+    def __init__(self, defaults: dict | None = None):
+        super().__init__(defaults)
+
+    @abstractmethod
+    def loss_transform(self, loss: torch.Tensor) -> torch.Tensor:
+        """transform the loss"""
+
+    @torch.no_grad
+    def step(self, var):
+        if var.loss is not None:
+            var.loss = self.loss_transform(var.loss)
+
+        closure = var.closure
+        if closure is None: raise RuntimeError("SquareHomotopy requires closure")
+
+        def homotopy_closure(backward=True):
+            if backward:
+                with torch.enable_grad():
+                    loss = self.loss_transform(closure(False))
+                    grad = torch.autograd.grad(loss, var.params, allow_unused=True)
+                    for p,g in zip(var.params, grad):
+                        p.grad = g
+            else:
+                loss = self.loss_transform(closure(False))
+
+            return loss
+
+        var.closure = homotopy_closure
+        return var
+
+class SquareHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return loss.square().copysign(loss)
+
+class SqrtHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return (loss+1e-12).sqrt()
+
+class ExpHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return loss.exp()
+
+class LogHomotopy(HomotopyBase):
+    def __init__(self): super().__init__()
+    def loss_transform(self, loss): return (loss+1e-12).log()
+
+class LambdaHomotopy(HomotopyBase):
+    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
+        defaults = dict(fn=fn)
+        super().__init__(defaults)
+
+    def loss_transform(self, loss): return self.defaults['fn'](loss)
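A speculative usage sketch for the new homotopy modules above: each module swaps `var.closure` for one that returns `loss_transform(loss)`, so later modules evaluate the transformed objective. The `tz.m.LogHomotopy` export path and the closure-based `opt.step(closure)` call are assumptions based on conventions seen elsewhere in this diff; the quadratic objective is purely illustrative.

```python
import torch
import torchzero as tz

x = torch.tensor([3.0, -2.0], requires_grad=True)

def f():
    # strictly positive toy objective so log() is defined
    return (x ** 2).sum() + 1.0

opt = tz.Modular(
    [x],
    tz.m.LogHomotopy(),  # assumed export of the class defined above
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

def closure(backward=True):
    loss = f()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    # downstream modules should see log(loss + 1e-12) through the rewritten closure
    opt.step(closure)
```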
@@ -1,12 +1,22 @@
 from collections import deque
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
+from functools import partial
 from operator import itemgetter
 from typing import Literal
 
 import torch
 
 from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
-from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+from ...utils import (
+    Distributions,
+    Metrics,
+    NumberList,
+    TensorList,
+    set_storage_,
+    tofloat,
+    unpack_dicts,
+    unpack_states,
+)
 
 
 class Previous(TensorwiseTransform):
@@ -139,7 +149,7 @@ class UpdateSign(Transform):
 
 class GraftToGrad(Transform):
     """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, uses_grad=True, target=target)
 
@@ -151,7 +161,7 @@ class GraftToGrad(Transform):
 
 class GraftGradToUpdate(Transform):
     """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, uses_grad=True, target=target)
 
@@ -164,7 +174,7 @@ class GraftGradToUpdate(Transform):
 
 class GraftToParams(Transform):
     """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4, target: Target = 'update'):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, uses_grad=False, target=target)
 
@@ -194,7 +204,7 @@ class FillLoss(Module):
     @torch.no_grad
     def step(self, var):
         alpha = self.get_settings(var.params, 'alpha')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        loss = var.get_loss(backward=self.defaults['backward'])
         var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
         return var
 
@@ -207,7 +217,7 @@ class MulByLoss(Module):
     @torch.no_grad
     def step(self, var):
         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        loss = var.get_loss(backward=self.defaults['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
         torch._foreach_mul_(var.update, mul)
         return var
@@ -221,7 +231,7 @@ class DivByLoss(Module):
     @torch.no_grad
     def step(self, var):
         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        loss = var.get_loss(backward=self.defaults['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
         torch._foreach_div_(var.update, mul)
         return var
@@ -229,15 +239,14 @@ class DivByLoss(Module):
 
 class NoiseSign(Transform):
     """Outputs random tensors with sign copied from the update."""
-    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
-        defaults = dict(distribution=distribution, alpha=alpha)
+    def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
+        defaults = dict(distribution=distribution, variance=variance)
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        alpha = [s['alpha'] for s in settings]
-        distribution = self.settings[params[0]]['distribution']
-        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+        variance = unpack_dicts(settings, 'variance')
+        return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
 
 class HpuEstimate(Transform):
     """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
@@ -257,7 +266,7 @@ class HpuEstimate(Transform):
         for p, c in zip(prev_params, params): p.copy_(c)
         for p, c in zip(prev_update, tensors): p.copy_(c)
         torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
-        self.store(params, ['s', 'y'], [s, y])
+        self.store(params, 'y', y)
 
     @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
@@ -295,7 +304,7 @@ class RandomHvp(Module):
 
         rgrad = None
         for i in range(n_samples):
-            u = params.sample_like(distribution=distribution)
+            u = params.sample_like(distribution=distribution, variance=1)
 
             Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                 h=h, normalize=True, retain_grad=i < n_samples-1)
@@ -314,3 +323,61 @@
 
         var.update = list(D)
         return var
+
+@torch.no_grad
+def _load_best_parameters(params: Sequence[torch.Tensor], best_params: Sequence[torch.Tensor]):
+    for p_cur, p_best in zip(params, best_params):
+        set_storage_(p_cur, p_best)
+
+class SaveBest(Module):
+    """Saves best parameters found so far, ones that have lowest loss. Put this as the last module.
+
+    Adds the following attrs:
+
+    - ``best_params`` - a list of tensors with best parameters.
+    - ``best_loss`` - loss value with ``best_params``.
+    - ``load_best_parameters`` - a function that sets parameters to the best parameters./
+
+    ## Examples
+    ```python
+    def rosenbrock(x, y):
+        return (1 - x)**2 + (100 * (y - x**2))**2
+
+    xy = torch.tensor((-1.1, 2.5), requires_grad=True)
+    opt = tz.Modular(
+        [xy],
+        tz.m.NAG(0.999),
+        tz.m.LR(1e-6),
+        tz.m.SaveBest()
+    )
+
+    # optimize for 1000 steps
+    for i in range(1000):
+        loss = rosenbrock(*xy)
+        opt.zero_grad()
+        loss.backward()
+        opt.step(loss=loss) # SaveBest needs closure or loss
+
+    # NAG overshot, but we saved the best params
+    print(f'{rosenbrock(*xy) = }') # >> 3.6583
+    print(f"{opt.attrs['best_loss'] = }") # >> 0.000627
+
+    # load best parameters
+    opt.attrs['load_best_params']()
+    print(f'{rosenbrock(*xy) = }') # >> 0.000627
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        loss = tofloat(var.get_loss(False))
+        lowest_loss = self.global_state.get('lowest_loss', float("inf"))
+
+        if loss < lowest_loss:
+            self.global_state['lowest_loss'] = loss
+            best_params = var.attrs['best_params'] = [p.clone() for p in var.params]
+            var.attrs['best_loss'] = loss
+            var.attrs['load_best_params'] = partial(_load_best_parameters, params=var.params, best_params=best_params)
+
+        return var
@@ -97,7 +97,7 @@ class NegateOnLossIncrease(Module):
     def step(self, var):
         closure = var.closure
         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-        backtrack = self.settings[var.params[0]]['backtrack']
+        backtrack = self.defaults['backtrack']
 
         update = var.get_update()
         f_0 = var.get_loss(backward=False)
@@ -123,36 +123,72 @@
 
 
 class Online(Module):
-    """Allows certain modules to be used for mini-batch optimization."""
-    def __init__(self, module: Chainable,):
+    """Allows certain modules to be used for mini-batch optimization.
+
+    Examples:
+
+    Online L-BFGS with Backtracking line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Online(tz.m.LBFGS()),
+        tz.m.Backtracking()
+    )
+    ```
+
+    Online L-BFGS trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
+    )
+    ```
+
+    """
+    def __init__(self, *modules: Module,):
         super().__init__()
 
-        self.set_child('module', module)
+        self.set_child('module', modules)
 
     @torch.no_grad
-    def step(self, var):
+    def update(self, var):
         closure = var.closure
         if closure is None: raise ValueError("Closure must be passed for Online")
+
         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step
+
         params = TensorList(var.params)
         p_cur = params.clone()
         p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+
         module = self.children['module']
+        var_c = var.clone(clone_update=False)
 
+        # on 1st step just step and store previous params
         if step == 1:
-            var = module.step(var.clone(clone_update=False))
-
             p_prev.copy_(params)
-            return var
 
-        # restore previous params
+            module.update(var_c)
+            var.update_attrs_from_clone_(var_c)
+            return
+
+        # restore previous params and update
         var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
         params.set_(p_prev)
         module.reset_for_online()
         module.update(var_prev)
 
-        # restore current params
+        # restore current params and update
         params.set_(p_cur)
         p_prev.copy_(params)
-        return module.step(var.clone(clone_update=False))
+        module.update(var_c)
+        var.update_attrs_from_clone_(var_c)
+
+    @torch.no_grad
+    def apply(self, var):
+        module = self.children['module']
+        return module.apply(var.clone(clone_update=False))
+
+    def get_H(self, var):
+        return self.children['module'].get_H(var)
@@ -1,12 +1,8 @@
-from collections import deque
-from collections.abc import Iterable
-from operator import itemgetter
-from typing import Literal
-
 import torch
 
-from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
-from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+from ...core import Chainable, Module, Target, Transform
+from ...core.reformulation import Reformulation
+from ...utils import Distributions, NumberList, TensorList
 
 
 class Dropout(Transform):
@@ -121,8 +117,8 @@ class PerturbWeights(Module):
     Args:
         alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
         relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
-        graft (bool, optional):
-            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
+        distribution (bool, optional):
+            distribution of the random perturbation. Defaults to False.
     """
     def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
         defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
1
- from collections.abc import Callable
1
+ import warnings
2
+ from collections.abc import Callable, Sequence, Iterable
2
3
  from typing import cast
3
4
 
4
5
  import torch
@@ -22,59 +23,78 @@ def _split(
22
23
  if var.update is not None:
23
24
  split_update = [u for i,u in enumerate(var.update) if i in idxs]
24
25
 
25
- split_var = var.clone(clone_update=False)
26
+ split_var = var.clone(clone_update=False, parent=var)
26
27
  split_var.params = split_params
27
28
  split_var.grad = split_grad
28
29
  split_var.update = split_update
29
30
 
30
31
  split_var = module.step(split_var)
31
32
 
32
- if (var.grad is None) and (split_var.grad is not None):
33
- var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
33
+ # those should be set due to var being parent
34
+ if split_var.grad is not None:
35
+ assert var.grad is not None
36
+
37
+ if split_var.loss is not None:
38
+ assert var.loss is not None
34
39
 
35
40
  if split_var.update is not None:
36
41
 
42
+ # make sure update is set, it will be filled with ``true`` and ``false`` tensors
37
43
  if var.update is None:
38
44
  if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
39
45
  else: var.update = [g.clone() for g in var.grad]
40
46
 
47
+ # set all tensors from this split
41
48
  for idx, u in zip(idxs, split_var.update):
42
49
  var.update[idx] = u
43
50
 
44
- var.update_attrs_from_clone_(split_var)
45
51
  return var
46
52
 
47
- class Split(Module):
48
- """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters.
53
+ _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.Tensor] | torch.nn.Module | Iterable[torch.nn.Module]
54
+ Filter = _SingleFilter | Iterable[_SingleFilter]
49
55
 
50
- Args:
51
- filter (Callable[[torch.Tensor], bool]): a function that takes in a parameter tensor and returns a boolean value.
52
- true (Chainable | None): modules that are applied to tensors where :code:`filter` returned True.
53
- false (Chainable | None): modules that are applied to tensors where :code:`filter` returned False.
54
-
55
- Examples:
56
- standard Muon with Adam fallback
57
-
58
- .. code-block:: python
59
-
60
- opt = tz.Modular(
61
- model.head.parameters(),
62
- tz.m.Split(
63
- # apply muon only to 2D+ parameters
64
- filter = lambda t: t.ndim >= 2,
65
- true = [
66
- tz.m.HeavyBall(),
67
- tz.m.Orthogonalize(),
68
- tz.m.LR(1e-2),
69
- ],
70
- false = tz.m.Adam()
71
- ),
72
- tz.m.LR(1e-2)
73
- )
56
+ def _make_filter(filter: Filter):
57
+ if callable(filter): return filter
58
+ if isinstance(filter, torch.Tensor):
59
+ return lambda x: x is filter
60
+ if isinstance(filter, torch.nn.Module):
61
+ return _make_filter(filter.parameters())
74
62
 
63
+ # iterable
64
+ filters = [_make_filter(f) for f in filter]
65
+ return lambda x: any(f(x) for f in filters)
75
66
 
67
+ class Split(Module):
68
+ """Apply ``true`` modules to all parameters filtered by ``filter``, apply ``false`` modules to all other parameters.
69
+
70
+ Args:
71
+ filter (Filter, bool]):
72
+ a filter that selects tensors to be optimized by ``true``.
73
+ - tensor or iterable of tensors (e.g. ``encoder.parameters()``).
74
+ - function that takes in tensor and outputs a bool (e.g. ``lambda x: x.ndim >= 2``).
75
+ - a sequence of above (acts as "or", so returns true if any of them is true).
76
+
77
+ true (Chainable | None): modules that are applied to tensors where ``filter`` is ``True``.
78
+ false (Chainable | None): modules that are applied to tensors where ``filter`` is ``False``.
79
+
80
+ ### Examples:
81
+
82
+ Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon
83
+
84
+ ```python
85
+ opt = tz.Modular(
86
+ model.parameters(),
87
+ tz.m.NAG(0.95),
88
+ tz.m.Split(
89
+ lambda p: p.ndim >= 2,
90
+ true = tz.m.Orthogonalize(),
91
+ false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
92
+ ),
93
+ tz.m.LR(1e-2),
94
+ )
95
+ ```
76
96
  """
77
- def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
97
+ def __init__(self, filter: Filter, true: Chainable | None, false: Chainable | None):
78
98
  defaults = dict(filter=filter)
79
99
  super().__init__(defaults)
80
100
 
@@ -84,7 +104,7 @@ class Split(Module):
84
104
  def step(self, var):
85
105
 
86
106
  params = var.params
87
- filter = self.settings[params[0]]['filter']
107
+ filter = _make_filter(self.settings[params[0]]['filter'])
88
108
 
89
109
  true_idxs = []
90
110
  false_idxs = []
@@ -92,11 +112,11 @@ class Split(Module):
92
112
  if filter(p): true_idxs.append(i)
93
113
  else: false_idxs.append(i)
94
114
 
95
- if 'true' in self.children:
115
+ if 'true' in self.children and len(true_idxs) > 0:
96
116
  true = self.children['true']
97
117
  var = _split(true, idxs=true_idxs, params=params, var=var)
98
118
 
99
- if 'false' in self.children:
119
+ if 'false' in self.children and len(false_idxs) > 0:
100
120
  false = self.children['false']
101
121
  var = _split(false, idxs=false_idxs, params=params, var=var)
102
122
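A hedged sketch of the widened ``filter`` argument above: filters can now be callables, tensors, parameter iterables, ``nn.Module``s, or a sequence of those combined as an "or". ``model`` with a ``head`` submodule is a placeholder, and the chosen modules simply mirror the docstring example.

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),  # placeholder model with a `head` submodule
    tz.m.Split(
        # True for any 2D+ tensor OR any parameter of model.head;
        # list(...) materializes the generator so the filter can be
        # re-evaluated on every step
        [lambda p: p.ndim >= 2, list(model.head.parameters())],
        true=tz.m.Orthogonalize(),
        false=tz.m.Adam(),
    ),
    tz.m.LR(1e-2),
)
```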