torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/misc/escape.py (new)
@@ -0,0 +1,62 @@
+ import math
+
+ from typing import Literal
+ import torch
+
+ from ...core import Modular, Module, Var, Chainable
+ from ...utils import NumberList, TensorList
+
+
+ class EscapeAnnealing(Module):
+     """If parameters stop changing, this runs a backward annealing random search"""
+     def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
+         defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
+         super().__init__(defaults)
+
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError("Escape requires closure")
+
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         max_region = self.get_settings(params, 'max_region', cls=NumberList)
+         max_iter = settings['max_iter']
+         tol = settings['tol']
+         n_tol = settings['n_tol']
+
+         n_bad = self.global_state.get('n_bad', 0)
+
+         prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+         diff = params-prev_params
+         prev_params.copy_(params)
+
+         if diff.abs().global_max() <= tol:
+             n_bad += 1
+
+         else:
+             n_bad = 0
+
+         self.global_state['n_bad'] = n_bad
+
+         # no progress
+         f_0 = var.get_loss(False)
+         if n_bad >= n_tol:
+             for i in range(1, max_iter+1):
+                 alpha = max_region * (i / max_iter)
+                 pert = params.sphere_like(radius=alpha)
+
+                 params.add_(pert)
+                 f_star = closure(False)
+
+                 if math.isfinite(f_star) and f_star < f_0-1e-12:
+                     var.update = None
+                     var.stop = True
+                     var.skip_update = True
+                     return var
+
+                 params.sub_(pert)
+
+             self.global_state['n_bad'] = 0
+         return var
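The module relies on the closure convention used throughout the package (``closure(False)`` evaluates the loss without a backward pass). A minimal usage sketch, assuming ``EscapeAnnealing`` is exported under ``tz.m`` like the other modules in this release and that ``Modular.step`` accepts such a closure:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-2),
    tz.m.EscapeAnnealing(max_region=1.0, n_tol=10),  # kicks in after 10 near-stationary steps
)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(500):
    opt.step(closure)
```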
torchzero/modules/misc/gradient_accumulation.py (new)
@@ -0,0 +1,136 @@
+ import torch
+
+ from ...core import Chainable, Module
+
+
+ # class GradientAccumulation(Module):
+ #     """Uses :code:`n` steps to accumulate gradients; after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updated.
+
+ #     Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
+ #     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+ #     .. note::
+ #         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+ #     Args:
+ #         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+ #         n (int): number of gradients to accumulate.
+ #         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+ #         stop (bool, optional):
+ #             this module prevents next modules from stepping unless :code:`n` gradients have been accumulated. Setting this argument to False disables that. Defaults to True.
+
+ #     Examples:
+ #         Adam with gradients accumulated for 16 batches.
+
+ #         .. code-block:: python
+
+ #             opt = tz.Modular(
+ #                 model.parameters(),
+ #                 tz.m.GradientAccumulation(
+ #                     [tz.m.Adam(), tz.m.LR(1e-2)],
+ #                     n=16
+ #                 )
+ #             )
+
+ #     """
+ #     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+ #         defaults = dict(n=n, mean=mean, stop=stop)
+ #         super().__init__(defaults)
+ #         self.set_child('modules', modules)
+
+
+ #     @torch.no_grad
+ #     def step(self, var):
+ #         accumulator = self.get_state(var.params, 'accumulator')
+ #         settings = self.defaults
+ #         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+ #         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+ #         # add update to accumulator
+ #         torch._foreach_add_(accumulator, var.get_update())
+
+ #         # step with accumulated updates
+ #         if step % n == 0:
+ #             if mean:
+ #                 torch._foreach_div_(accumulator, n)
+
+ #             var.update = [a.clone() for a in accumulator]
+ #             var = self.children['modules'].step(var)
+
+ #             # zero accumulator
+ #             torch._foreach_zero_(accumulator)
+
+ #         else:
+ #             # prevent update
+ #             if stop:
+ #                 var.update = None
+ #                 var.stop=True
+ #                 var.skip_update=True
+
+ #         return var
+
+
+
+
+ class GradientAccumulation(Module):
+     """Uses ``n`` steps to accumulate gradients; after ``n`` gradients have been accumulated, they are passed to the following modules and parameters are updated.
+
+     Accumulating gradients for ``n`` steps is equivalent to increasing batch size by ``n``. Increasing the batch size
+     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+     Note:
+         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+     Args:
+         n (int): number of gradients to accumulate.
+         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+         stop (bool, optional):
+             this module prevents next modules from stepping unless ``n`` gradients have been accumulated. Setting this argument to False disables that. Defaults to True.
+
+     ## Examples:
+
+     Adam with gradients accumulated for 16 batches.
+
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.GradientAccumulation(n=16),
+         tz.m.Adam(),
+         tz.m.LR(1e-2),
+     )
+     ```
+     """
+     def __init__(self, n: int, mean=True, stop=True):
+         defaults = dict(n=n, mean=mean, stop=stop)
+         super().__init__(defaults)
+
+
+     @torch.no_grad
+     def step(self, var):
+         accumulator = self.get_state(var.params, 'accumulator')
+         settings = self.defaults
+         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         # add update to accumulator
+         torch._foreach_add_(accumulator, var.get_update())
+
+         # step with accumulated updates
+         if step % n == 0:
+             if mean:
+                 torch._foreach_div_(accumulator, n)
+
+             var.update = accumulator
+
+             # zero accumulator
+             self.clear_state_keys('accumulator')
+
+         else:
+             # prevent update
+             if stop:
+                 var.update = None
+                 var.stop=True
+                 var.skip_update=True
+
+         return var
+
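For the new chain-based form, a training-loop sketch (assuming, as in the docstring example, that ``GradientAccumulation`` sits first in the chain and that gradients populated by ``loss.backward()`` are picked up when ``opt.step()`` is called without a closure):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation(n=16),  # mean=True averages the 16 accumulated gradients
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

for step in range(160):
    xb, yb = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(xb), yb)
    opt.zero_grad()
    loss.backward()
    opt.step()  # parameters change only on every 16th call; the other calls accumulate
```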
torchzero/modules/misc/homotopy.py (new)
@@ -0,0 +1,59 @@
+ from collections.abc import Callable
+ from abc import ABC, abstractmethod
+ import torch
+ from ...core import Module
+ from ...core import Chainable
+
+ class HomotopyBase(Module):
+     def __init__(self, defaults: dict | None = None):
+         super().__init__(defaults)
+
+     @abstractmethod
+     def loss_transform(self, loss: torch.Tensor) -> torch.Tensor:
+         """transform the loss"""
+
+     @torch.no_grad
+     def step(self, var):
+         if var.loss is not None:
+             var.loss = self.loss_transform(var.loss)
+
+         closure = var.closure
+         if closure is None: raise RuntimeError("SquareHomotopy requires closure")
+
+         def homotopy_closure(backward=True):
+             if backward:
+                 with torch.enable_grad():
+                     loss = self.loss_transform(closure(False))
+                     grad = torch.autograd.grad(loss, var.params, allow_unused=True)
+                     for p,g in zip(var.params, grad):
+                         p.grad = g
+             else:
+                 loss = self.loss_transform(closure(False))
+
+             return loss
+
+         var.closure = homotopy_closure
+         return var
+
+ class SquareHomotopy(HomotopyBase):
+     def __init__(self): super().__init__()
+     def loss_transform(self, loss): return loss.square().copysign(loss)
+
+ class SqrtHomotopy(HomotopyBase):
+     def __init__(self): super().__init__()
+     def loss_transform(self, loss): return (loss+1e-12).sqrt()
+
+ class ExpHomotopy(HomotopyBase):
+     def __init__(self): super().__init__()
+     def loss_transform(self, loss): return loss.exp()
+
+ class LogHomotopy(HomotopyBase):
+     def __init__(self): super().__init__()
+     def loss_transform(self, loss): return (loss+1e-12).log()
+
+ class LambdaHomotopy(HomotopyBase):
+     def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
+         defaults = dict(fn=fn)
+         super().__init__(defaults)
+
+     def loss_transform(self, loss): return self.defaults['fn'](loss)
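These modules replace the closure so that every later module in the chain sees the transformed loss and its gradient. A usage sketch for ``LogHomotopy``, assuming the classes are exported under ``tz.m`` and the usual closure-with-backward-flag convention:

```python
import torch
import torchzero as tz

x = torch.tensor([3.0, -2.0], requires_grad=True)

# place the homotopy first so the rest of the chain optimizes log(f(x) + 1e-12)
opt = tz.Modular([x], tz.m.LogHomotopy(), tz.m.Adam(), tz.m.LR(1e-2))

def closure(backward=True):
    loss = (x ** 2).sum() + 1.0  # keep the raw loss positive so the log is defined
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(200):
    opt.step(closure)
```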
torchzero/modules/misc/misc.py (new)
@@ -0,0 +1,383 @@
+ from collections import deque
+ from collections.abc import Iterable, Sequence
+ from functools import partial
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+ from ...utils import (
+     Distributions,
+     Metrics,
+     NumberList,
+     TensorList,
+     set_storage_,
+     tofloat,
+     unpack_dicts,
+     unpack_states,
+ )
+
+
+ class Previous(TensorwiseTransform):
+     """Maintains an update from ``n`` steps back; for example, if ``n=1``, returns the previous update."""
+     def __init__(self, n=1, target: Target = 'update'):
+         defaults = dict(n=n)
+         super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         n = setting['n']
+
+         if 'history' not in state:
+             state['history'] = deque(maxlen=n+1)
+
+         state['history'].append(tensor)
+
+         return state['history'][0]
+
+
+ class LastDifference(Transform):
+     """Outputs difference between past two updates."""
+     def __init__(self,target: Target = 'update'):
+         super().__init__({}, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
+         difference = torch._foreach_sub(tensors, prev_tensors)
+         for p, c in zip(prev_tensors, tensors): p.set_(c)
+         return difference
+
+ class LastGradDifference(Module):
+     """Outputs difference between past two gradients."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, var):
+         grad = var.get_grad()
+         prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
+         difference = torch._foreach_sub(grad, prev_grad)
+         for p, c in zip(prev_grad, grad): p.copy_(c)
+         var.update = list(difference)
+         return var
+
+ class LastParamDifference(Module):
+     """Outputs difference between past two parameters, which is the effective previous update."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         prev_params = self.get_state(var.params, 'prev_params') # initialized to 0
+         difference = torch._foreach_sub(params, prev_params)
+         for p, c in zip(prev_params, params): p.copy_(c)
+         var.update = list(difference)
+         return var
+
+
+
+ class LastProduct(Transform):
+     """Outputs product of past two updates."""
+     def __init__(self,target: Target = 'update'):
+         super().__init__({}, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+         prod = torch._foreach_mul(tensors, prev)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return prod
+
+ class LastRatio(Transform):
+     """Outputs ratio between past two updates; the numerator is determined by the :code:`numerator` argument."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+         defaults = dict(numerator=numerator)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+         numerator = settings[0]['numerator']
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+         else: ratio = torch._foreach_div(prev, tensors)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return ratio
+
+ class LastAbsoluteRatio(Transform):
+     """Outputs ratio between absolute values of past two updates; the numerator is determined by the :code:`numerator` argument."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+         defaults = dict(numerator=numerator, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+         numerator = settings[0]['numerator']
+         eps = NumberList(s['eps'] for s in settings)
+
+         torch._foreach_abs_(tensors)
+         torch._foreach_clamp_min_(prev, eps)
+
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+         else: ratio = torch._foreach_div(prev, tensors)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return ratio
+
+ class GradSign(Transform):
+     """Copies gradient sign to update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         return [t.copysign_(g) for t,g in zip(tensors, grads)]
+
+ class UpdateSign(Transform):
+     """Outputs gradient with sign copied from the update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
+
+ class GraftToGrad(Transform):
+     """Grafts update to the gradient, that is, the update is rescaled to have the same norm as the gradient."""
+     def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class GraftGradToUpdate(Transform):
+     """Outputs gradient grafted to update, that is, the gradient rescaled to have the same norm as the update."""
+     def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+ class GraftToParams(Transform):
+     """Grafts update to the parameters, that is, the update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
+     def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class Relative(Transform):
+     """Multiplies update by absolute parameter values to make it relative to their magnitude; :code:`min_value` is the minimum allowed value to avoid getting stuck at 0."""
+     def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+         defaults = dict(min_value=min_value)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
+         torch._foreach_mul_(tensors, mul)
+         return tensors
+
+ class FillLoss(Module):
+     """Outputs tensors filled with loss value times :code:`alpha`"""
+     def __init__(self, alpha: float = 1, backward: bool = True):
+         defaults = dict(alpha=alpha, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha = self.get_settings(var.params, 'alpha')
+         loss = var.get_loss(backward=self.defaults['backward'])
+         var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+         return var
+
+ class MulByLoss(Module):
+     """Multiplies update by loss times :code:`alpha`"""
+     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+         loss = var.get_loss(backward=self.defaults['backward'])
+         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+         torch._foreach_mul_(var.update, mul)
+         return var
+
+ class DivByLoss(Module):
+     """Divides update by loss times :code:`alpha`"""
+     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+         loss = var.get_loss(backward=self.defaults['backward'])
+         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+         torch._foreach_div_(var.update, mul)
+         return var
+
+
+ class NoiseSign(Transform):
+     """Outputs random tensors with sign copied from the update."""
+     def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
+         defaults = dict(distribution=distribution, variance=variance)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         variance = unpack_dicts(settings, 'variance')
+         return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
+
+ class HpuEstimate(Transform):
+     """Returns ``y/||s||``, where ``y`` is the difference between current and previous update (gradient), and ``s`` is the difference between current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the previous update."""
+     def __init__(self):
+         defaults = dict()
+         super().__init__(defaults, uses_grad=False)
+
+     def reset_for_online(self):
+         super().reset_for_online()
+         self.clear_state_keys('prev_params', 'prev_update')
+
+     @torch.no_grad
+     def update_tensors(self, tensors, params, grads, loss, states, settings):
+         prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
+         s = torch._foreach_sub(params, prev_params)
+         y = torch._foreach_sub(tensors, prev_update)
+         for p, c in zip(prev_params, params): p.copy_(c)
+         for p, c in zip(prev_update, tensors): p.copy_(c)
+         torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
+         self.store(params, 'y', y)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         return [self.state[p]['y'] for p in params]
+
+ class RandomHvp(Module):
+     """Returns a Hessian-vector product with a random vector"""
+
+     def __init__(
+         self,
+         n_samples: int = 1,
+         distribution: Distributions = "normal",
+         update_freq: int = 1,
+         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         h=1e-3,
+     ):
+         defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+         distribution = settings['distribution']
+         hvp_method = settings['hvp_method']
+         h = settings['h']
+         update_freq = settings['update_freq']
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         D = None
+         if step % update_freq == 0:
+
+             rgrad = None
+             for i in range(n_samples):
+                 u = params.sample_like(distribution=distribution, variance=1)
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                       h=h, normalize=True, retain_grad=i < n_samples-1)
+
+                 if D is None: D = Hvp
+                 else: torch._foreach_add_(D, Hvp)
+
+             if n_samples > 1: torch._foreach_div_(D, n_samples)
+             if update_freq != 1:
+                 assert D is not None
+                 D_buf = self.get_state(params, "D", cls=TensorList)
+                 D_buf.set_(D)
+
+         if D is None:
+             D = self.get_state(params, "D", cls=TensorList)
+
+         var.update = list(D)
+         return var
+
+ @torch.no_grad
+ def _load_best_parameters(params: Sequence[torch.Tensor], best_params: Sequence[torch.Tensor]):
+     for p_cur, p_best in zip(params, best_params):
+         set_storage_(p_cur, p_best)
+
+ class SaveBest(Module):
+     """Saves the best parameters found so far, i.e. the ones with the lowest loss. Put this as the last module.
+
+     Adds the following attrs:
+
+     - ``best_params`` - a list of tensors with best parameters.
+     - ``best_loss`` - loss value with ``best_params``.
+     - ``load_best_params`` - a function that sets parameters to the best parameters.
+
+     ## Examples
+     ```python
+     def rosenbrock(x, y):
+         return (1 - x)**2 + (100 * (y - x**2))**2
+
+     xy = torch.tensor((-1.1, 2.5), requires_grad=True)
+     opt = tz.Modular(
+         [xy],
+         tz.m.NAG(0.999),
+         tz.m.LR(1e-6),
+         tz.m.SaveBest()
+     )
+
+     # optimize for 1000 steps
+     for i in range(1000):
+         loss = rosenbrock(*xy)
+         opt.zero_grad()
+         loss.backward()
+         opt.step(loss=loss) # SaveBest needs closure or loss
+
+     # NAG overshot, but we saved the best params
+     print(f'{rosenbrock(*xy) = }') # >> 3.6583
+     print(f"{opt.attrs['best_loss'] = }") # >> 0.000627
+
+     # load best parameters
+     opt.attrs['load_best_params']()
+     print(f'{rosenbrock(*xy) = }') # >> 0.000627
+     """
+     def __init__(self):
+         super().__init__()
+
+     @torch.no_grad
+     def step(self, var):
+         loss = tofloat(var.get_loss(False))
+         lowest_loss = self.global_state.get('lowest_loss', float("inf"))
+
+         if loss < lowest_loss:
+             self.global_state['lowest_loss'] = loss
+             best_params = var.attrs['best_params'] = [p.clone() for p in var.params]
+             var.attrs['best_loss'] = loss
+             var.attrs['load_best_params'] = partial(_load_best_parameters, params=var.params, best_params=best_params)
+
+         return var
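The grafting transforms above combine the direction of one quantity with the magnitude of another. A sketch of grafting an Adam direction onto the raw gradient's norm, assuming ``GraftToGrad`` is exported under ``tz.m`` and that gradients produced by ``loss.backward()`` are available to the chain:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 2)
X, y = torch.randn(32, 10), torch.randn(32, 2)

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),          # produces the update direction
    tz.m.GraftToGrad(),   # rescales that update to the gradient's norm
    tz.m.LR(1e-2),
)

for _ in range(100):
    loss = torch.nn.functional.mse_loss(model(X), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```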