torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/structured_newton.py
@@ -1,111 +0,0 @@
- # idea https://arxiv.org/pdf/2212.09841
- import warnings
- from collections.abc import Callable
- from functools import partial
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, apply_transform
- from ...utils import TensorList, vec_to_tensors
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     hessian_mat,
-     hvp,
-     hvp_fd_central,
-     hvp_fd_forward,
-     jacobian_and_hessian_wrt,
- )
-
-
- class StructuredNewton(Module):
-     """TODO. Please note that this is experimental and isn't guaranteed to work.
-     Args:
-         structure (str, optional): structure.
-         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-         hvp_method (str):
-             how to calculate hvp_method. Defaults to "autograd".
-         inner (Chainable | None, optional): inner modules. Defaults to None.
-
-     """
-     def __init__(
-         self,
-         structure: Literal[
-             "diagonal",
-             "diagonal1",
-             "diagonal_abs",
-             "tridiagonal",
-             "circulant",
-             "toeplitz",
-             "toeplitz_like",
-             "hankel",
-             "rank1",
-             "rank2", # any rank
-         ]
-         | str = "diagonal",
-         reg: float = 1e-6,
-         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-         h: float = 1e-3,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
-         super().__init__(defaults)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         reg = settings['reg']
-         hvp_method = settings['hvp_method']
-         structure = settings['structure']
-         h = settings['h']
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         if hvp_method == 'autograd':
-             grad = var.get_grad(create_graph=True)
-             def Hvp_fn1(x):
-                 return hvp(params, grad, x, retain_graph=True)
-             Hvp_fn = Hvp_fn1
-
-         elif hvp_method == 'forward':
-             grad = var.get_grad()
-             def Hvp_fn2(x):
-                 return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
-             Hvp_fn = Hvp_fn2
-
-         elif hvp_method == 'central':
-             grad = var.get_grad()
-             def Hvp_fn3(x):
-                 return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
-             Hvp_fn = Hvp_fn3
-
-         else: raise ValueError(hvp_method)
-
-         # -------------------------------- inner step -------------------------------- #
-         update = var.get_update()
-         if 'inner' in self.children:
-             update = apply_transform(self.children['inner'], update, params=params, grads=grad, var=var)
-
-         # hessian
-         if structure.startswith('diagonal'):
-             H = Hvp_fn([torch.ones_like(p) for p in params])
-             if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
-             if structure == 'diagonal_abs': torch._foreach_abs_(H)
-             torch._foreach_add_(H, reg)
-             torch._foreach_div_(update, H)
-             var.update = update
-             return var
-
-         # hessian
-         raise NotImplementedError(structure)
-
-
-
-
-
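The only structure this removed module actually implements is the diagonal family: it takes a single Hessian-vector product with an all-ones vector (H @ 1, the Hessian row sums) as a cheap diagonal estimate, regularizes it, and divides the update by it. A minimal standalone sketch of that idea, roughly the "diagonal_abs" variant; the helper name and signature below are illustrative and not part of torchzero:

import torch

def diagonal_newton_step(params, loss, reg=1e-6, lr=1.0):
    # gradients with a graph so a second backward gives Hessian-vector products
    grads = torch.autograd.grad(loss, params, create_graph=True)
    ones = [torch.ones_like(p) for p in params]
    # H @ 1 (Hessian row sums), used as a cheap diagonal proxy as in the module above
    diag = torch.autograd.grad(grads, params, grad_outputs=ones)
    with torch.no_grad():
        for p, g, d in zip(params, grads, diag):
            p -= lr * g / (d.abs() + reg)

# usage on a toy objective
x = torch.randn(5, requires_grad=True)
loss = (x ** 4).sum()
diagonal_newton_step([x], loss)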
torchzero/modules/experimental/subspace_preconditioners.py
@@ -1,138 +0,0 @@
- from collections import deque
-
- import torch
- # import visualbench as vb
-
- # import torchzero as tz
-
- from ...core import Transform, Chainable, apply_transform
- from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
- from ...utils import TensorList, vec_to_tensors_
-
-
- def inverse_sqrt(M):
-     if M.shape[-1] == 2: return inv_sqrt_2x2(M, force_pd=True) # general formula for 2x2 matrices
-     return matrix_power_eigh(M, -1/2)
-
- def update_subspace_preconditioner_(
-     grad: torch.Tensor, # store grads and basis as vectors for matmul
-     basis: torch.Tensor, # ndim, k
-     accumulator_: torch.Tensor, # k, k
-     beta: float | None,
- ):
-     projected = basis.T @ grad # k
-     outer = torch.outer(projected, projected)
-
-     if beta is None: accumulator_.add_(outer)
-     else: accumulator_.lerp_(outer, 1-beta)
-
- def apply_subspace_preconditioner(
-     tensor: torch.Tensor,
-     basis: torch.Tensor, # ndim, k
-     accumulator: torch.Tensor,
- ):
-     preconditioner = inverse_sqrt(accumulator) # k,k
-
-     tensor_projected = basis.T @ tensor # k
-     update_projected = preconditioner @ tensor_projected # k
-     return basis @ update_projected # d
-
- class RandomSubspacePreconditioning(Transform):
-     """Whitens in random slowly changing subspace. Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
-         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-         super().__init__(defaults, uses_grad=False)
-
-         if inner is not None: self.set_child('inner', inner)
-
-     def apply(self, tensors, params, grads, loss, states, settings):
-         settings = settings[0]
-         g = torch.cat([t.view(-1) for t in tensors])
-         k = settings['k']
-         beta = settings['beta']
-         basis_beta = settings['basis_beta']
-
-         if 'basis' not in self.global_state:
-             self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-             self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-
-         basis = self.global_state['basis']
-         accumulator = self.global_state['accumulator']
-
-         if basis_beta is not None:
-             basis.lerp_(torch.randn_like(basis), 1-basis_beta)
-
-         update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-         if 'inner' in self.children:
-             tensors = apply_transform(self.children['inner'], tensors, params, grads)
-             g = torch.cat([t.view(-1) for t in tensors])
-
-         try:
-             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-         except torch.linalg.LinAlgError:
-             preconditioned = g.clip(-0.1, 0.1)
-         vec_to_tensors_(preconditioned, tensors)
-
-         return tensors
-
-
- class HistorySubspacePreconditioning(Transform):
-     """Whitens in subspace spanned by history of gradient differences.
-     Please note that this is experimental and isn't guaranteed to work.
-
-     Args:
-         beta - for preconditioner itself in the basis.
-         basis_beta - how much basis is allowed to change.
-     """
-     def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
-         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-         super().__init__(defaults, uses_grad=False)
-
-         if inner is not None: self.set_child('inner', inner)
-
-     def apply(self, tensors, params, grads, loss, states, settings):
-         settings = settings[0]
-
-         g = torch.cat([t.view(-1) for t in tensors])
-         k = settings['k']
-         beta = settings['beta']
-         basis_beta = settings['basis_beta']
-
-         if 'history' not in self.global_state:
-             self.global_state['history'] = deque(maxlen=k)
-             self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-             self.global_state['basis'] = torch.ones(g.numel(), k, device=g.device, dtype=g.dtype)
-
-
-         history: deque = self.global_state['history']
-         accumulator = self.global_state['accumulator']
-         basis = self.global_state['basis']
-
-         history.append(g)
-         if len(history) < k:
-             basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-             history_basis = torch.stack(tuple(history), -1)
-             basis_t[:, -len(history):] = history_basis
-
-         else:
-             basis_t = torch.stack(tuple(history), -1)
-
-         basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]
-         basis_t = (basis_t - basis_t.mean()) / basis_t.std()
-
-         basis.lerp_(basis_t, 1-basis_beta)
-         update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-         if 'inner' in self.children:
-             tensors = apply_transform(self.children['inner'], tensors, params, grads)
-             g = torch.cat([t.view(-1) for t in tensors])
-
-         try:
-             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-         except torch.linalg.LinAlgError:
-             preconditioned = g.clip(-0.1,0.1)
-         vec_to_tensors_(preconditioned, tensors)
-
-         return tensors
-
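Both removed preconditioners share one core step: project the flattened gradient onto a k-dimensional basis, keep an exponential moving average of the projected outer products, and whiten with the inverse square root of that accumulator before mapping back to parameter space. A self-contained sketch of one such step on plain tensors; the function name and the eigendecomposition-based inverse square root are illustrative, not the torchzero API:

import torch

def subspace_whiten(g, basis, accumulator, beta=0.99):
    # g: (d,) flattened gradient, basis: (d, k), accumulator: (k, k)
    proj = basis.T @ g                                     # project into the k-dim subspace
    accumulator.lerp_(torch.outer(proj, proj), 1 - beta)   # EMA of projected outer products
    # inverse matrix square root via eigendecomposition (accumulator is symmetric PSD)
    vals, vecs = torch.linalg.eigh(accumulator)
    inv_sqrt = (vecs * vals.clamp_min(1e-12).rsqrt()) @ vecs.T
    return basis @ (inv_sqrt @ proj)                       # whitened step, back in R^d

d, k = 10, 3
g = torch.randn(d)
basis = torch.randn(d, k)
acc = torch.eye(k)
step = subspace_whiten(g, basis, acc)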
torchzero/modules/experimental/tada.py
@@ -1,38 +0,0 @@
- from collections import deque
-
- import torch
-
- from ...core import Chainable, TensorwiseTransform
- from ...utils.linalg import matrix_power_eigh
-
-
- class TAda(TensorwiseTransform):
-     """3rd order whitening (maybe normalizes skewness). Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
-         defaults = dict(history_size=history_size, reg=reg)
-         super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
-
-     @torch.no_grad
-     def update_tensor(self, tensor, param, grad, loss, state, settings):
-         reg = settings['reg']
-         if 'history' not in state:
-             state['history'] = deque(maxlen=settings['history_size'])
-
-         g = tensor.view(-1)
-         history = state['history']
-         history.append(g.clone())
-
-         I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
-         g_k = history[0]
-         outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-         if len(history) > 1:
-             for g_k in list(history)[1:]:
-                 outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-
-         state['outer'] = outer.add_(I)
-
-     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
-         outer = state['outer']
-         P = matrix_power_eigh(outer, -1/2)
-         return (P @ tensor.ravel()).view_as(tensor)
torchzero/modules/line_search/trust_region.py
@@ -1,73 +0,0 @@
- from operator import itemgetter
-
- import torch
-
- from .line_search import LineSearch
-
-
- class TrustRegion(LineSearch):
-     """Basic first order trust region method. Re-evaluates the function after stepping, if value decreased sufficiently,
-     step size is increased. If value increased, step size is decreased. This is prone to collapsing.
-
-     Args:
-         nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
-         nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
-         c (float, optional): descent condition. Defaults to 1e-4.
-         init (float, optional): initial step size. Defaults to 1.
-         backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
-         adaptive (bool, optional):
-             If enabled, when multiple consecutive steps have been successful or unsuccessful,
-             the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
-     """
-     def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
-         defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def search(self, update, var):
-
-         nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
-         step_size = self.global_state.setdefault('step_size', init)
-         previous_success = self.global_state.setdefault('previous_success', False)
-         nplus_mul = self.global_state.setdefault('nplus_mul', 1)
-         nminus_mul = self.global_state.setdefault('nminus_mul', 1)
-
-
-         f_0 = self.evaluate_step_size(0, var, backward=False)
-
-         # directional derivative (0 if c = 0 because it is not needed)
-         if c == 0: d = 0
-         else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
-
-         # test step size
-         sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
-
-         f_1 = self.evaluate_step_size(step_size, var, backward=False)
-
-         proposed = step_size
-
-         # very good step
-         if f_1 < sufficient_f:
-             self.global_state['step_size'] *= nplus * nplus_mul
-
-             # two very good steps in a row - increase nplus_mul
-             if adaptive:
-                 if previous_success: self.global_state['nplus_mul'] *= nplus
-                 else: self.global_state['nplus_mul'] = 1
-
-         # acceptable step step
-         #elif f_1 <= f_0: pass
-
-         # bad step
-         if f_1 >= f_0:
-             self.global_state['step_size'] *= nminus * nminus_mul
-
-             # two bad steps in a row - decrease nminus_mul
-             if adaptive:
-                 if previous_success: self.global_state['nminus_mul'] *= nminus
-                 else: self.global_state['nminus_mul'] = 1
-
-             if backtrack: proposed = 0
-             else: proposed *= nminus * nminus_mul
-
-         return proposed
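Stripped of the module bookkeeping, the removed TrustRegion line search adapts a single scalar step size: grow it by nplus when the re-evaluated loss beats the sufficient-decrease target f_0 + c·h·min(d, 0), shrink it by nminus (and optionally backtrack) when the loss did not decrease. A plain-Python sketch of that rule with a worked example; the function name is illustrative, not the torchzero API:

def adapt_step_size(f0, f1, directional_deriv, step_size,
                    nplus=1.5, nminus=0.75, c=1e-4):
    # sufficient-decrease target, as in the removed module
    sufficient = f0 + c * step_size * min(directional_deriv, 0)
    if f1 < sufficient:          # good step: grow the step size
        return step_size * nplus
    if f1 >= f0:                 # bad step: shrink the step size
        return step_size * nminus
    return step_size             # acceptable step: keep it

h = 1.0
h = adapt_step_size(f0=10.0, f1=9.0, directional_deriv=-2.0, step_size=h)   # -> 1.5
h = adapt_step_size(f0=9.0, f1=9.5, directional_deriv=-1.0, step_size=h)    # -> 1.125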
torchzero/modules/lr/__init__.py
@@ -1,2 +0,0 @@
- from .lr import LR, StepSize, Warmup
- from .adaptive import PolyakStepSize, RandomStepSize
torchzero/modules/lr/adaptive.py
@@ -1,93 +0,0 @@
- """Various step size strategies"""
- import random
- from typing import Any
- from operator import itemgetter
- import torch
-
- from ...core import Transform
- from ...utils import TensorList, NumberList, unpack_dicts
-
-
- class PolyakStepSize(Transform):
-     """Polyak's step-size method.
-
-     Args:
-         max (float | None, optional): maximum possible step size. Defaults to None.
-         min_obj_value (int, optional):
-             (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-         use_grad (bool, optional):
-             if True, uses dot product of update and gradient to compute the step size.
-             Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
-             Defaults to True.
-         parameterwise (bool, optional):
-             if True, calculate Polyak step-size for each parameter separately,
-             if False calculate one global step size for all parameters. Defaults to False.
-         alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
-     """
-     def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
-
-         defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
-         super().__init__(defaults, uses_grad=use_grad)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         tensors = TensorList(tensors)
-         grads = TensorList(grads)
-         alpha = NumberList(s['alpha'] for s in settings)
-
-         parameterwise, use_grad, max, min_obj_value = itemgetter('parameterwise', 'use_grad', 'max', 'min_obj_value')(settings[0])
-
-         if use_grad: denom = tensors.dot(grads)
-         else: denom = tensors.dot(tensors)
-
-         if parameterwise:
-             polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
-             polyak_step_size = polyak_step_size.where(denom != 0, 0)
-             if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
-
-         else:
-             if denom.abs() <= torch.finfo(denom.dtype).eps: polyak_step_size = 0 # converged
-             else: polyak_step_size = (loss - min_obj_value) / denom
-
-             if max is not None:
-                 if polyak_step_size > max: polyak_step_size = max
-
-         tensors.mul_(alpha * polyak_step_size)
-         return tensors
-
-
- class RandomStepSize(Transform):
-     """Uses random global or layer-wise step size from `low` to `high`.
-
-     Args:
-         low (float, optional): minimum learning rate. Defaults to 0.
-         high (float, optional): maximum learning rate. Defaults to 1.
-         parameterwise (bool, optional):
-             if True, generate random step size for each parameter separately,
-             if False generate one global random step size. Defaults to False.
-     """
-     def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
-         defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         s = settings[0]
-         parameterwise = s['parameterwise']
-
-         seed = s['seed']
-         if 'generator' not in self.global_state:
-             self.global_state['generator'] = random.Random(seed)
-         generator: random.Random = self.global_state['generator']
-
-         if parameterwise:
-             low, high = unpack_dicts(settings, 'low', 'high')
-             lr = [generator.uniform(l, h) for l, h in zip(low, high)]
-         else:
-             low = s['low']
-             high = s['high']
-             lr = generator.uniform(low, high)
-
-         torch._foreach_mul_(tensors, lr)
-         return tensors
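PolyakStepSize above is the classic Polyak rule: scale the update by alpha = (f(x) − f*) / ⟨update, grad⟩, where f* is an estimate of the lowest achievable loss, optionally clamped to a maximum step; the file list suggests the step-size modules were superseded by torchzero/modules/step_size/ in 0.3.13. A minimal scalar sketch, not the torchzero API:

import torch

def polyak_step_size(loss, grad, update, f_star=0.0, max_step=None):
    # alpha = (f(x) - f*) / <update, grad>, the classic Polyak rule
    denom = sum((u * g).sum() for u, g in zip(update, grad))
    if denom.abs() <= torch.finfo(denom.dtype).eps:
        return 0.0                       # (near-)converged, avoid division by zero
    alpha = float((loss - f_star) / denom)
    if max_step is not None:
        alpha = min(alpha, max_step)
    return alpha

g = [torch.tensor([0.4, -0.2])]
alpha = polyak_step_size(loss=0.5, grad=g, update=g)   # steepest-descent direction -> 2.5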
torchzero/modules/lr/lr.py
@@ -1,63 +0,0 @@
- """Learning rate"""
- import torch
-
- from ...core import Transform
- from ...utils import NumberList, TensorList, generic_eq, unpack_dicts
-
- def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
-     """multiplies by lr if lr is not 1"""
-     if generic_eq(lr, 1): return tensors
-     if inplace: return tensors.mul_(lr)
-     return tensors * lr
-
- class LR(Transform):
-     """Learning rate. Adding this module also adds support for LR schedulers."""
-     def __init__(self, lr: float):
-         defaults=dict(lr=lr)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
-
- class StepSize(Transform):
-     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
-     def __init__(self, step_size: float, key = 'step_size'):
-         defaults={"key": key, key: step_size}
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
-
-
- def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
-     """returns warm up lr scalar"""
-     if step > steps: return end_lr
-     return start_lr + (end_lr - start_lr) * (step / steps)
-
- class Warmup(Transform):
-     """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
-
-     Args:
-         start_lr (_type_, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
-         end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
-         steps (int, optional): number of steps to perform warmup for. Defaults to 100.
-     """
-     def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
-         defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
-         num_steps = settings[0]['steps']
-         step = self.global_state.get('step', 0)
-
-         target = lazy_lr(
-             TensorList(tensors),
-             lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
-             inplace=True
-         )
-         self.global_state['step'] = step + 1
-         return target
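The removed Warmup multiplier is a plain linear ramp, lr(t) = start_lr + (end_lr − start_lr)·t/steps, held at end_lr once t exceeds steps. A worked example with an illustrative helper, not the torchzero API:

def warmup_lr(step, start_lr=1e-5, end_lr=1.0, steps=100):
    # linear ramp from start_lr to end_lr over `steps` steps, flat afterwards
    if step > steps:
        return end_lr
    return start_lr + (end_lr - start_lr) * (step / steps)

print(warmup_lr(0))    # 1e-05
print(warmup_lr(50))   # ~0.5
print(warmup_lr(200))  # 1.0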
torchzero/modules/momentum/matrix_momentum.py
@@ -1,166 +0,0 @@
- from typing import Literal
-
- import torch
-
- from ...core import Module, apply_transform, Chainable
- from ...utils import NumberList, TensorList, as_tensorlist
- from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
- class MatrixMomentum(Module):
-     """
-     May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-     Evaluates hessian vector product on each step (via finite difference or autograd).
-
-     `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
-
-     Args:
-         mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
-         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-         hvp_method (str, optional):
-             How to calculate hessian-vector products.
-             Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
-         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-     Reference:
-         Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-     """
-
-     def __init__(
-         self,
-         mu=0.1,
-         beta: float = 1,
-         hvp_method: Literal["autograd", "forward", "central"] = "forward",
-         h: float = 1e-3,
-         hvp_tfm: Chainable | None = None,
-     ):
-         defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
-         super().__init__(defaults)
-
-         if hvp_tfm is not None:
-             self.set_child('hvp_tfm', hvp_tfm)
-
-     @torch.no_grad
-     def step(self, var):
-         assert var.closure is not None
-         prev_update = self.get_state(var.params, 'prev_update', cls=TensorList)
-         hvp_method = self.settings[var.params[0]]['hvp_method']
-         h = self.settings[var.params[0]]['h']
-
-         mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
-
-         if hvp_method == 'autograd':
-             with torch.enable_grad():
-                 grad = var.get_grad(create_graph=True)
-                 hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-         elif hvp_method == 'forward':
-             var.get_grad()
-             l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-             if var.loss_approx is None: var.loss_approx = l
-
-         elif hvp_method == 'central':
-             l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
-             if var.loss_approx is None: var.loss_approx = l
-
-         else:
-             raise ValueError(hvp_method)
-
-         if 'hvp_tfm' in self.children:
-             hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
-
-         update = TensorList(var.get_update())
-
-         hvp_ = as_tensorlist(hvp_)
-         update.add_(prev_update - hvp_*mu)
-         prev_update.set_(update * beta)
-         var.update = update
-         return var
-
-
- class AdaptiveMatrixMomentum(Module):
-     """
-     May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-     Evaluates hessian vector product on each step (via finite difference or autograd).
-
-     This version estimates mu via a simple heuristic: ||s||/||y||, where s is parameter difference, y is gradient difference.
-
-     Args:
-         mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
-         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-         hvp_method (str, optional):
-             How to calculate hessian-vector products.
-             Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
-         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-     Reference:
-         Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-     """
-
-     def __init__(
-         self,
-         mu_mul: float = 1,
-         beta: float = 1,
-         eps=1e-4,
-         hvp_method: Literal["autograd", "forward", "central"] = "forward",
-         h: float = 1e-3,
-         hvp_tfm: Chainable | None = None,
-     ):
-         defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
-         super().__init__(defaults)
-
-         if hvp_tfm is not None:
-             self.set_child('hvp_tfm', hvp_tfm)
-
-     @torch.no_grad
-     def step(self, var):
-         assert var.closure is not None
-         prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
-
-         settings = self.settings[var.params[0]]
-         hvp_method = settings['hvp_method']
-         h = settings['h']
-         eps = settings['eps']
-
-         mu_mul, beta = self.get_settings(var.params, 'mu_mul','beta', cls=NumberList)
-
-         if hvp_method == 'autograd':
-             with torch.enable_grad():
-                 grad = var.get_grad(create_graph=True)
-                 hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-         elif hvp_method == 'forward':
-             var.get_grad()
-             l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-             if var.loss_approx is None: var.loss_approx = l
-
-         elif hvp_method == 'central':
-             l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
-             if var.loss_approx is None: var.loss_approx = l
-
-         else:
-             raise ValueError(hvp_method)
-
-         if 'hvp_tfm' in self.children:
-             hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
-
-         # adaptive part
-         update = TensorList(var.get_update())
-
-         s_k = var.params - prev_params
-         prev_params.copy_(var.params)
-
-         assert var.grad is not None
-         y_k = var.grad - prev_grad
-         prev_grad.copy_(var.grad)
-
-         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
-
-         # matrix momentum uppdate
-         hvp_ = as_tensorlist(hvp_)
-         update.add_(prev_update - hvp_*ada_mu)
-         prev_update.set_(update * beta)
-         var.update = update
-         return var
-
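Both removed classes implement the matrix momentum update of Orr & Leen: u_t = g_t + p − mu·H·p with buffer p ← beta·u_t, where H·p is a Hessian-vector product; the adaptive variant estimates mu ≈ mu_mul·‖s‖/‖y‖ from the last parameter and gradient differences, and the file list shows a reworked matrix_momentum.py under torchzero/modules/adaptive/ in 0.3.13. A flattened single-tensor sketch on a toy quadratic; the names are illustrative, not the torchzero Module API:

import torch

def matrix_momentum_step(grad, prev_update, hvp_of_prev, mu=0.1, beta=1.0):
    # u_t = g_t + p_{t-1} - mu * H @ p_{t-1};  p_t = beta * u_t
    update = grad + prev_update - mu * hvp_of_prev
    new_prev = beta * update
    return update, new_prev

# quadratic toy problem f(x) = 0.5 x^T H x, so grad = H x and the Hvp is exact
H = torch.diag(torch.tensor([1.0, 10.0, 100.0]))
x = torch.ones(3)
prev = torch.zeros(3)
for _ in range(10):
    g = H @ x
    update, prev = matrix_momentum_step(g, prev, H @ prev, mu=1 / 100.0)  # mu < 1/largest eigenvalue
    x = x - 0.01 * update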