torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -0,0 +1,128 @@
+ from collections import deque
+
+ import torch
+ # import visualbench as vb
+
+ # import torchzero as tz
+
+ from ...core import Transform, Chainable, apply
+ from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
+ from ...utils import TensorList, vec_to_tensors_
+
+
+ def inverse_sqrt(M):
+     if M.shape[-1] == 2: return inv_sqrt_2x2(M, force_pd=True) # general formula for 2x2 matrices
+     return matrix_power_eigh(M, -1/2)
+
+ def update_subspace_preconditioner_(
+     grad: torch.Tensor, # store grads and basis as vectors for matmul
+     basis: torch.Tensor, # ndim, k
+     accumulator_: torch.Tensor, # k, k
+     beta: float | None,
+ ):
+     projected = basis.T @ grad # k
+     outer = torch.outer(projected, projected)
+
+     if beta is None: accumulator_.add_(outer)
+     else: accumulator_.lerp_(outer, 1-beta)
+
+ def apply_subspace_preconditioner(
+     tensor: torch.Tensor,
+     basis: torch.Tensor, # ndim, k
+     accumulator: torch.Tensor,
+ ):
+     preconditioner = inverse_sqrt(accumulator) # k,k
+
+     tensor_projected = basis.T @ tensor # k
+     update_projected = preconditioner @ tensor_projected # k
+     return basis @ update_projected # d
+
+ class RandomSubspacePreconditioning(Transform):
+     """Full-matrix RMSprop in a random subspace."""
+     def __init__(self, k: int, beta: float | None = 0.99):
+         defaults = dict(k=k, beta=beta)
+         super().__init__(defaults, uses_grad=False)
+
+     def transform(self, tensors, params, grads, vars):
+         settings = self.settings[params[0]]
+         g = torch.cat([t.view(-1) for t in tensors])
+         k = settings['k']
+         beta = settings['beta']
+
+         if 'basis' not in self.global_state:
+             self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
+             self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
+
+         basis = self.global_state['basis']
+         accumulator = self.global_state['accumulator']
+
+         update_subspace_preconditioner_(g, basis, accumulator, beta)
+         try:
+             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
+         except torch.linalg.LinAlgError:
+             denom = g.abs().sum()
+             if denom <= 1e-10: denom = torch.ones_like(denom)
+             preconditioned = g / denom
+         vec_to_tensors_(preconditioned, tensors)
+
+         return tensors
+
+
+ class HistorySubspacePreconditioning(Transform):
+     """Full-matrix RMSprop in the subspace spanned by a history of gradient differences.
+
+     basis_beta controls how quickly the basis is allowed to change, and beta is for the preconditioner itself within the basis.
+     """
+     def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
+         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
+         super().__init__(defaults, uses_grad=False)
+
+         if inner is not None: self.set_child('inner', inner)
+
+     def transform(self, tensors, params, grads, vars):
+         settings = self.settings[params[0]]
+
+         g = torch.cat([t.view(-1) for t in tensors])
+         k = settings['k']
+         beta = settings['beta']
+         basis_beta = settings['basis_beta']
+
+         if 'history' not in self.global_state:
+             self.global_state['history'] = deque(maxlen=k)
+             self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
+             self.global_state['basis'] = torch.ones(g.numel(), k, device=g.device, dtype=g.dtype)
+
+
+         history: deque = self.global_state['history']
+         accumulator = self.global_state['accumulator']
+         basis = self.global_state['basis']
+
+         history.append(g)
+         if len(history) < k:
+             basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
+             history_basis = torch.stack(tuple(history), -1)
+             basis_t[:, -len(history):] = history_basis
+
+         else:
+             basis_t = torch.stack(tuple(history), -1)
+
+         basis_t[:, :-1] = basis_t[:, :-1] - basis_t[:, 1:]
+         basis_t = (basis_t - basis_t.mean()) / basis_t.std()
+
+         basis.lerp_(basis_t, 1-basis_beta)
+         update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+         if 'inner' in self.children:
+             tensors = apply(self.children['inner'], tensors, params, grads, vars)
+             g = torch.cat([t.view(-1) for t in tensors])
+
+         try:
+             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
+         except torch.linalg.LinAlgError:
+             denom = g.abs().sum()
+             if denom <= 1e-10: denom = torch.ones_like(denom)
+             preconditioned = g / denom
+         vec_to_tensors_(preconditioned, tensors)
+
+         return tensors
+
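
The two transforms above are full-matrix RMSprop restricted to a k-dimensional subspace: the flattened gradient g is projected through a d-by-k basis B, a running second moment C of B^T g is accumulated, and the update is B C^(-1/2) B^T g mapped back to parameter space. A standalone sketch of that math in plain torch (illustrative only; it does not use the package's inv_sqrt_2x2/matrix_power_eigh helpers and assumes a fixed random basis):

    import torch

    def inverse_sqrt_psd(M, eps=1e-12):
        # eigendecomposition-based M^(-1/2) for a symmetric PSD matrix
        eigvals, eigvecs = torch.linalg.eigh(M)
        return eigvecs @ torch.diag(eigvals.clamp(min=eps).rsqrt()) @ eigvecs.T

    d, k, beta = 1000, 8, 0.99
    basis = torch.randn(d, k)          # fixed random subspace basis B (d x k)
    accumulator = torch.eye(k)         # running second moment C of B^T g

    for _ in range(100):
        g = torch.randn(d)             # stand-in for a flattened gradient
        projected = basis.T @ g        # B^T g, shape (k,)
        accumulator.lerp_(torch.outer(projected, projected), 1 - beta)
        update = basis @ (inverse_sqrt_psd(accumulator) @ projected)   # preconditioned step, back to shape (d,)
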
@@ -0,0 +1,136 @@
+ import warnings
+ from functools import partial
+ from typing import Literal
+ from collections.abc import Callable
+ import torch
+
+ from ...core import Chainable, apply, Module
+ from ...utils import vec_to_tensors, TensorList
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     hessian_mat,
+     jacobian_and_hessian_wrt,
+ )
+ from ..second_order.newton import lu_solve, cholesky_solve, least_squares_solve
+
+ def tropical_sum(x, dim): return torch.amax(x, dim=dim)
+ def tropical_mul(x, y): return x+y
+
+ def tropical_matmul(x: torch.Tensor, y: torch.Tensor):
+     # this implements matmul by calling mul and sum
+
+     x_squeeze = False
+     y_squeeze = False
+
+     if x.ndim == 1:
+         x_squeeze = True
+         x = x.unsqueeze(0)
+
+     if y.ndim == 1:
+         y_squeeze = True
+         y = y.unsqueeze(1)
+
+     res = tropical_sum(tropical_mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim = -2)
+
+     if x_squeeze: res = res.squeeze(-2)
+     if y_squeeze: res = res.squeeze(-1)
+
+     return res
+
+ def tropical_dot(x:torch.Tensor, y:torch.Tensor):
+     assert x.ndim == 1 and y.ndim == 1
+     return tropical_matmul(x.unsqueeze(0), y.unsqueeze(1))
+
+ def tropical_outer(x:torch.Tensor, y:torch.Tensor):
+     assert x.ndim == 1 and y.ndim == 1
+     return tropical_matmul(x.unsqueeze(1), y.unsqueeze(0))
+
+
+ def tropical_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+     r = b.unsqueeze(1) - A
+     return r.amin(dim=-2)
+
+ def tropical_solve_and_reconstruct(A: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     r = b.unsqueeze(1) - A
+     x = r.amin(dim=-2)
+     b_hat = tropical_matmul(A, x)
+     return x, b_hat
+
+ def tikhonov(H: torch.Tensor, reg: float):
+     if reg != 0: H += torch.eye(H.size(-1), dtype=H.dtype, device=H.device) * reg
+     return H
+
+
+ class TropicalNewton(Module):
+     """Newton-like method that solves the Newton system in tropical (max-plus) arithmetic, optionally interpolating with an ordinary Newton solve when the tropical solution is inaccurate."""
+     def __init__(
+         self,
+         reg: float | None = None,
+         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+         vectorize: bool = True,
+         interpolate: bool = False,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, interpolate=interpolate)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = TensorList(vars.params)
+         closure = vars.closure
+         if closure is None: raise RuntimeError('TropicalNewton requires closure')
+
+         settings = self.settings[params[0]]
+         reg = settings['reg']
+         hessian_method = settings['hessian_method']
+         vectorize = settings['vectorize']
+         interpolate = settings['interpolate']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         if hessian_method == 'autograd':
+             with torch.enable_grad():
+                 loss = vars.loss = vars.loss_approx = closure(False)
+                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                 g_list = [t[0] for t in g_list] # remove leading dim from loss
+                 vars.grad = g_list
+                 H = hessian_list_to_mat(H_list)
+
+         elif hessian_method in ('func', 'autograd.functional'):
+             strat = 'forward-mode' if vectorize else 'reverse-mode'
+             with torch.enable_grad():
+                 g_list = vars.get_grad(retain_graph=True)
+                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                                               method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+         else:
+             raise ValueError(hessian_method)
+
+         # -------------------------------- inner step -------------------------------- #
+         if 'inner' in self.children:
+             g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
+         g = torch.cat([t.view(-1) for t in g_list])
+
+         # ------------------------------ regularization ------------------------------ #
+         if reg is not None: H = tikhonov(H, reg)
+
+         # ----------------------------------- solve ---------------------------------- #
+         tropical_update, g_hat = tropical_solve_and_reconstruct(H, g)
+
+         g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+         abs_error = torch.linalg.vector_norm(g - g_hat) # pylint:disable=not-callable
+         rel_error = abs_error / g_norm.clip(min=1e-8)
+
+         if interpolate:
+             if rel_error > 1e-8:
+
+                 update = cholesky_solve(H, g)
+                 if update is None: update = lu_solve(H, g)
+                 if update is None: update = least_squares_solve(H, g)
+
+                 tropical_update.lerp_(update.ravel(), rel_error.clip(max=1))
+
+         vars.update = vec_to_tensors(tropical_update, params)
+         return vars
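
In the (max, +) tropical semiring used above, "addition" is max and "multiplication" is +, so tropical_matmul computes (A ⊗ x)_i = max_j (A_ij + x_j), and tropical_solve returns the residuation x_j = min_i (b_i - A_ij); in max-plus algebra this is the greatest x satisfying A ⊗ x <= b elementwise, with equality where the system is actually solvable. A small self-contained check of both operations (illustrative only, plain torch, not the module's helpers):

    import torch

    # (max, +) matrix-vector product: (A ⊗ x)_i = max_j (A_ij + x_j)
    def maxplus_matvec(A, x):
        return torch.amax(A + x.unsqueeze(0), dim=-1)

    A = torch.tensor([[0., 2.], [1., -1.]])
    b = torch.tensor([3., 5.])

    # residuation, as in tropical_solve above: x_j = min_i (b_i - A_ij)
    x = torch.amin(b.unsqueeze(1) - A, dim=-2)
    print(maxplus_matvec(A, x), b)   # elementwise <= b, equal in the solvable components
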
@@ -0,0 +1,209 @@
1
+ """
2
+ Arguments that are modified in-place are denoted with "_" at the end.
3
+
4
+ Some functions return one of the arguments which was modified in-place, some return new tensors.
5
+ Make sure to keep track of that to avoid unexpected in-place modifications of buffers. The returned
6
+ storage is always indicated in the docstring.
7
+
8
+ Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
9
+ """
10
+
11
+ from collections.abc import Callable, Sequence
12
+
13
+ from ..utils import NumberList, TensorList
14
+
15
+ inf = float('inf')
16
+
17
+ def debiased_step_size(
18
+ step,
19
+ beta1: float | NumberList | None = None,
20
+ beta2: float | NumberList | None = None,
21
+ pow: float = 2,
22
+ alpha: float | NumberList = 1,
23
+ ):
24
+ """returns multiplier to step size"""
25
+ if isinstance(beta1, NumberList): beta1 = beta1.fill_none(0)
26
+ if isinstance(beta2, NumberList): beta2 = beta2.fill_none(0)
27
+
28
+ step_size = alpha
29
+ if beta1 is not None:
30
+ bias_correction1 = 1.0 - (beta1 ** step)
31
+ step_size /= bias_correction1
32
+ if beta2 is not None:
33
+ bias_correction2 = 1.0 - (beta2 ** step)
34
+ step_size *= bias_correction2 ** (1/pow)
35
+ return step_size
36
+
37
+ def debias(
38
+ tensors_: TensorList,
39
+ step: int,
40
+ inplace: bool,
41
+ beta1: float | NumberList | None = None,
42
+ beta2: float | NumberList | None = None,
43
+ alpha: float | NumberList = 1,
44
+ pow: float = 2,
45
+ ):
46
+ step_size = debiased_step_size(step=step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
47
+ if inplace: return tensors_.mul_(step_size)
48
+ return tensors_ * step_size
49
+
50
+ def debias_second_momentum(tensors_:TensorList, step: int, beta: float | NumberList, pow: float, inplace:bool):
51
+ """debias 2nd momentum, optionally in-place"""
52
+ bias_correction2 = (1.0 - (beta ** step)) ** (1/pow)
53
+ if inplace: return tensors_.div_(bias_correction2)
54
+ return tensors_ / bias_correction2
55
+
56
+ def lerp_power_(tensors:TensorList, exp_avg_pow_:TensorList, beta:float|NumberList, pow:float) -> TensorList:
57
+ """
58
+ Lerp `exp_avg_pow_` with `tensors ^ pow`
59
+
60
+ Returns `exp_avg_pow_`.
61
+ """
62
+ if pow == 1: return exp_avg_pow_.lerp_(tensors.abs(), 1-beta)
63
+ if pow == 2: return exp_avg_pow_.mul_(beta).addcmul_(tensors, tensors, value = 1-beta)
64
+ if pow % 2 == 0: return exp_avg_pow_.lerp_(tensors.pow(pow), 1-beta)
65
+ return exp_avg_pow_.lerp_(tensors.pow(pow).abs_(), 1-beta)
66
+
67
+ def add_power_(tensors:TensorList, sum_:TensorList, pow:float) -> TensorList:
68
+ """
69
+ Add `tensors ^ pow` to `sum_`
70
+
71
+ Returns `sum_`.
72
+ """
73
+ if pow == 1: return sum_.add_(tensors.abs())
74
+ if pow == 2: return sum_.addcmul_(tensors, tensors)
75
+ if pow % 2 == 0: return sum_.add_(tensors.pow(pow))
76
+ return sum_.add_(tensors.pow(pow).abs_())
77
+
78
+
79
+ def root(tensors_:TensorList, p:float, inplace: bool):
80
+ """
81
+ Root of tensors, optionally in-place.
82
+
83
+ Returns `tensors_` if `inplace` else new tensors.
84
+ """
85
+ if inplace:
86
+ if p == 1: return tensors_.abs_()
87
+ if p == 2: return tensors_.sqrt_()
88
+ return tensors_.pow_(1/p)
89
+ else:
90
+ if p == 1: return tensors_.abs()
91
+ if p == 2: return tensors_.sqrt()
92
+ return tensors_.pow(1/p)
93
+
94
+
95
+ def ema_(
96
+ tensors: TensorList,
97
+ exp_avg_: TensorList,
98
+ beta: float | NumberList,
99
+ dampening: float | NumberList = 0,
100
+ lerp: bool = True,
101
+ ):
102
+ """
103
+ Updates `exp_avg_` with EMA of `tensors`.
104
+
105
+ Returns `exp_avg_`.
106
+ """
107
+ tensors.lazy_mul_(1 - dampening)
108
+ if lerp: return exp_avg_.lerp_(tensors, (1 - beta))
109
+ return exp_avg_.mul_(beta).add_(tensors)
110
+
111
+ def ema_sq_(
112
+ tensors: TensorList,
113
+ exp_avg_sq_: TensorList,
114
+ beta: float | NumberList,
115
+ max_exp_avg_sq_: TensorList | None,
116
+ pow: float = 2,
117
+ ):
118
+ """
119
+ Updates `exp_avg_sq_` with EMA of squared `tensors`, if `max_exp_avg_sq_` is not None, updates it with maximum of EMA.
120
+
121
+ Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
122
+ """
123
+ lerp_power_(tensors=tensors, exp_avg_pow_=exp_avg_sq_,beta=beta,pow=pow)
124
+
125
+ # AMSGrad
126
+ if max_exp_avg_sq_ is not None:
127
+ max_exp_avg_sq_.maximum_(exp_avg_sq_)
128
+ exp_avg_sq_ = max_exp_avg_sq_
129
+
130
+ return exp_avg_sq_
131
+
132
+ def sqrt_ema_sq_(
133
+ tensors: TensorList,
134
+ exp_avg_sq_: TensorList,
135
+ beta: float | NumberList,
136
+ max_exp_avg_sq_: TensorList | None,
137
+ debiased: bool,
138
+ step: int,
139
+ pow: float = 2,
140
+ ema_sq_fn: Callable = ema_sq_,
141
+ ):
142
+ """
143
+ Updates `exp_avg_sq_` with EMA of squared `tensors` and calculates it's square root,
144
+ with optional AMSGrad and debiasing.
145
+
146
+ Returns new tensors.
147
+ """
148
+ exp_avg_sq_=ema_sq_fn(
149
+ tensors=tensors,
150
+ exp_avg_sq_=exp_avg_sq_,
151
+ beta=beta,
152
+ max_exp_avg_sq_=max_exp_avg_sq_,
153
+ pow=pow,
154
+ )
155
+
156
+ sqrt_exp_avg_sq = root(exp_avg_sq_, pow, inplace=False)
157
+
158
+ if debiased: sqrt_exp_avg_sq = debias_second_momentum(sqrt_exp_avg_sq, step=step, beta=beta, pow=pow, inplace=True)
159
+ return sqrt_exp_avg_sq
160
+
161
+
162
+ def centered_ema_sq_(tensors: TensorList, exp_avg_: TensorList, exp_avg_sq_: TensorList,
163
+ beta: float | NumberList, max_exp_avg_sq_: TensorList | None = None, pow:float=2):
164
+ """
165
+ Updates `exp_avg_` and `exp_avg_sq_` with EMA of `tensors` and squared `tensors`,
166
+ centers `exp_avg_sq_` by subtracting `exp_avg_` squared.
167
+
168
+ Returns `max_exp_avg_sq_` or new tensors.
169
+ """
170
+ exp_avg_sq_ = ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta, max_exp_avg_sq_=max_exp_avg_sq_, pow=pow)
171
+ exp_avg_.lerp_(tensors, 1-beta)
172
+ exp_avg_sq_ = exp_avg_sq_.addcmul(exp_avg_, exp_avg_, value=-1)
173
+
174
+ # AMSGrad
175
+ if max_exp_avg_sq_ is not None:
176
+ max_exp_avg_sq_.maximum_(exp_avg_sq_)
177
+ exp_avg_sq_ = max_exp_avg_sq_
178
+
179
+ return exp_avg_sq_
180
+
181
+ def sqrt_centered_ema_sq_(
182
+ tensors: TensorList,
183
+ exp_avg_: TensorList,
184
+ exp_avg_sq_: TensorList,
185
+ max_exp_avg_sq_: TensorList | None,
186
+ beta: float | NumberList,
187
+ debiased: bool,
188
+ step: int,
189
+ pow: float = 2,
190
+ ):
191
+ """
192
+ Updates `exp_avg_` and `exp_avg_sq_` with EMA of `tensors` and squared `tensors`,
193
+ centers `exp_avg_sq_` by subtracting `exp_avg_` squared. Calculates it's square root,
194
+ with optional AMSGrad and debiasing.
195
+
196
+ Returns new tensors.
197
+ """
198
+ return sqrt_ema_sq_(
199
+ tensors=tensors,
200
+ exp_avg_sq_=exp_avg_sq_,
201
+ beta=beta,
202
+ max_exp_avg_sq_=max_exp_avg_sq_,
203
+ debiased=debiased,
204
+ step=step,
205
+ pow=pow,
206
+ ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
207
+ )
208
+
209
+
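
debiased_step_size above folds Adam-style bias correction into a single step-size multiplier: with pow=2 it computes alpha * sqrt(1 - beta2^t) / (1 - beta1^t), which matches the textbook m_hat / sqrt(v_hat) form up to where eps is added. A small illustration in plain torch (illustrative only, not the package's TensorList API):

    import torch

    beta1, beta2, alpha, t = 0.9, 0.999, 1e-3, 10
    g = torch.randn(4)
    m = torch.zeros(4); v = torch.zeros(4)

    m.lerp_(g, 1 - beta1)                            # first-moment EMA
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)    # second-moment EMA (as lerp_power_ does for pow=2)

    # the multiplier that debiased_step_size computes for pow=2
    step_size = alpha * (1 - beta2 ** t) ** 0.5 / (1 - beta1 ** t)
    update = step_size * m / (v.sqrt() + 1e-8)
    # equivalent to alpha * m_hat / (sqrt(v_hat) + eps') with m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t)
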
@@ -0,0 +1,4 @@
+ from .grad_approximator import GradApproximator, GradTarget
+ from .fdm import FDM
+ from .rfdm import RandomizedFDM, MeZO, SPSA, RDSA, GaussianSmoothing
+ from .forward_gradient import ForwardGradient
@@ -0,0 +1,120 @@
+ from collections.abc import Callable
+ from typing import Any
+
+ import torch
+
+ from ...utils import TensorList
+ from .grad_approximator import GradApproximator, GradTarget, _FD_Formula
+
+
+ def _forward2(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+     if v_0 is None: v_0 = closure(False)
+     assert param.ndim == 1
+     param[idx] += h
+     v_plus = closure(False)
+     param[idx] -= h
+     return v_0, v_0, (v_plus - v_0) / h # (loss, loss_approx, grad)
+
+ def _backward2(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+     if v_0 is None: v_0 = closure(False)
+     assert param.ndim == 1
+     param[idx] -= h
+     v_minus = closure(False)
+     param[idx] += h
+     return v_0, v_0, (v_0 - v_minus) / h
+
+ def _central2(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: Any):
+     assert param.ndim == 1
+     param[idx] += h
+     v_plus = closure(False)
+
+     param[idx] -= h * 2
+     v_minus = closure(False)
+
+     param[idx] += h
+     return v_0, v_plus, (v_plus - v_minus) / (2 * h)
+
+ def _forward3(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+     if v_0 is None: v_0 = closure(False)
+     assert param.ndim == 1
+     param[idx] += h
+     v_plus1 = closure(False)
+
+     param[idx] += h
+     v_plus2 = closure(False)
+
+     param[idx] -= 2 * h
+     return v_0, v_0, (-3*v_0 + 4*v_plus1 - v_plus2) / (2 * h)
+
+ def _backward3(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+     if v_0 is None: v_0 = closure(False)
+     assert param.ndim == 1
+     param[idx] -= h
+     v_minus1 = closure(False)
+
+     param[idx] -= h
+     v_minus2 = closure(False)
+
+     param[idx] += 2 * h
+     return v_0, v_0, (v_minus2 - 4*v_minus1 + 3*v_0) / (2 * h)
+
+ def _central4(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: Any):
+     assert param.ndim == 1
+
+     param[idx] += h
+     v_plus1 = closure(False)
+
+     param[idx] += h
+     v_plus2 = closure(False)
+
+     param[idx] -= 3 * h
+     v_minus1 = closure(False)
+
+     param[idx] -= h
+     v_minus2 = closure(False)
+
+     param[idx] += 2 * h
+     return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)
+
+ _FD_FUNCS = {
+     "forward2": _forward2,
+     "backward2": _backward2,
+     "central2": _central2,
+     "central3": _central2, # they are the same
+     "forward3": _forward3,
+     "backward3": _backward3,
+     "central4": _central4,
+ }
+
+
+ class FDM(GradApproximator):
+     """Approximate gradients via finite difference method
+
+     Args:
+         h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
+         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+         target (GradTarget, optional): what to set on vars. Defaults to 'closure'.
+     """
+     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central2', target: GradTarget = 'closure'):
+         defaults = dict(h=h, formula=formula)
+         super().__init__(defaults, target=target)
+
+     @torch.no_grad
+     def approximate(self, closure, params, loss, vars):
+         grads = []
+         loss_approx = None
+
+         for p in params:
+             g = torch.zeros_like(p)
+             grads.append(g)
+
+             settings = self.settings[p]
+             h = settings['h']
+             fd_fn = _FD_FUNCS[settings['formula']]
+
+             p_flat = p.view(-1); g_flat = g.view(-1)
+             for i in range(len(p_flat)):
+                 loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
+                 g_flat[i] = d
+
+         return grads, loss, loss_approx
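
_FD_FUNCS above maps formula names to standard finite-difference stencils: central2 is (f(x+h) - f(x-h)) / (2h) with O(h^2) error, forward3/backward3 are the one-sided second-order stencils, and central4 is (f(x-2h) - 8 f(x-h) + 8 f(x+h) - f(x+2h)) / (12h) with O(h^4) error. A quick scalar sanity check of the two central stencils (illustrative only):

    import math

    f, x, h = math.sin, 0.5, 1e-3
    true = math.cos(x)
    central2 = (f(x + h) - f(x - h)) / (2 * h)
    central4 = (f(x - 2*h) - 8*f(x - h) + 8*f(x + h) - f(x + 2*h)) / (12 * h)
    print(abs(central2 - true), abs(central4 - true))   # the central4 error is several orders of magnitude smaller
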
@@ -0,0 +1,81 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import torch
+
+ from ...utils import Distributions, NumberList, TensorList, generic_eq
+ from ...utils.derivatives import jvp, jvp_fd_central, jvp_fd_forward
+ from .grad_approximator import GradApproximator, GradTarget
+ from .rfdm import RandomizedFDM
+
+
+ class ForwardGradient(RandomizedFDM):
+     """Forward gradient method: same as randomized finite differences, but the directional derivative is estimated via autograd (as a jacobian-vector product).
+
+     Args:
+         n_samples (int, optional): number of random gradient samples. Defaults to 1.
+         distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
+         beta (float, optional):
+             if not 0, acts as momentum on gradient samples, making the subspace spanned by them change slowly. Defaults to 0.
+         pre_generate (bool, optional):
+             whether to pre-generate gradient samples before each step. Defaults to True.
+         jvp_method (str, optional):
+             how to calculate the jacobian-vector product; note that with 'forward' and 'central' this is identical to randomized finite differences. Defaults to 'autograd'.
+         h (float, optional): finite difference step size if `jvp_method` is set to 'forward' or 'central'. Defaults to 1e-3.
+         target (GradTarget, optional): what to set on vars. Defaults to "closure".
+     """
+     PRE_MULTIPLY_BY_H = False
+     def __init__(
+         self,
+         n_samples: int = 1,
+         distribution: Distributions = "gaussian",
+         beta: float = 0,
+         pre_generate = True,
+         jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+         h: float = 1e-3,
+         target: GradTarget = "closure",
+         seed: int | None | torch.Generator = None,
+     ):
+         super().__init__(h=h, n_samples=n_samples, distribution=distribution, beta=beta, target=target, pre_generate=pre_generate, seed=seed)
+         self.defaults['jvp_method'] = jvp_method
+
+     @torch.no_grad
+     def approximate(self, closure, params, loss, vars):
+         params = TensorList(params)
+         loss_approx = None
+
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+         jvp_method = settings['jvp_method']
+         h = settings['h']
+         distribution = settings['distribution']
+         default = [None]*n_samples
+         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+         generator = self._get_generator(settings['seed'], params)
+
+         grad = None
+         for i in range(n_samples):
+             prt = perturbations[i]
+             if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator)
+             else: prt = TensorList(prt)
+
+             if jvp_method == 'autograd':
+                 with torch.enable_grad():
+                     loss, d = jvp(partial(closure, False), params=params, tangent=prt)
+
+             elif jvp_method == 'forward':
+                 loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, normalize=True, h=h)
+
+             elif jvp_method == 'central':
+                 loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, normalize=True, h=h)
+
+             else: raise ValueError(jvp_method)
+
+             if grad is None: grad = prt * d
+             else: grad += prt * d
+
+         assert grad is not None
+         if n_samples > 1: grad.div_(n_samples)
+         return grad, loss, loss_approx
+
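
ForwardGradient above implements the forward-gradient estimator: draw a random direction v, obtain the directional derivative grad(f)(x) . v from a single forward-mode JVP, and use (grad(f)(x) . v) v as the gradient estimate; this is unbiased when E[v v^T] = I (e.g. Gaussian v), and averaging n_samples such estimates reduces variance. A minimal standalone sketch with torch.func.jvp (assumes a torch version that ships torch.func; this is not the module's own jvp helper):

    import torch

    def f(x):
        return (x ** 2).sum()              # toy objective

    x = torch.randn(5)
    v = torch.randn_like(x)                # random tangent direction (Gaussian, so E[v v^T] = I)

    # one forward-mode JVP returns the loss and the directional derivative grad(f)(x) . v
    loss, dderiv = torch.func.jvp(f, (x,), (v,))
    g_hat = dderiv * v                     # unbiased gradient estimate; average several for lower variance
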