torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/diagonal_higher_order_newton.py (deleted)
@@ -1,225 +0,0 @@
- import itertools
- import math
- import warnings
- from collections.abc import Callable
- from contextlib import nullcontext
- from functools import partial
- from typing import Any, Literal
-
- import numpy as np
- import scipy.optimize
- import torch
-
- from ...core import Chainable, Module, apply_transform
- from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     jacobian_wrt,
-     hvp,
- )
-
- def _poly_eval_diag(s: np.ndarray, c, derivatives):
-     val = float(c) + (derivatives[0] * s).sum(-1)
-
-     if len(derivatives) > 1:
-         for i, d_diag in enumerate(derivatives[1:], 2):
-             val += (d_diag * (s**i)).sum(-1) / math.factorial(i)
-
-     return val
-
- def _proximal_poly_v_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-     """Computes the value of the proximal polynomial approximation."""
-     if x.ndim == 2: x = x.T
-     s = x - x0
-
-     val = _poly_eval_diag(s, c, derivatives)
-
-     penalty = 0
-     if prox != 0:
-         penalty = (prox / 2) * (s**2).sum(-1)
-
-     return val + penalty
-
- def _proximal_poly_g_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-     """Computes the gradient of the proximal polynomial approximation."""
-     s = x - x0
-
-     g = derivatives[0].copy()
-
-     if len(derivatives) > 1:
-         for i, d_diag in enumerate(derivatives[1:], 2):
-             g += d_diag * (s**(i - 1)) / math.factorial(i - 1)
-
-     if prox != 0:
-         g += prox * s
-
-     return g
-
- def _proximal_poly_H_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-     """Computes the Hessian of the proximal polynomial approximation."""
-     s = x - x0
-     n = x.shape[0]
-
-     if len(derivatives) < 2:
-         H_diag = np.zeros(n, dtype=s.dtype)
-     else:
-         H_diag = derivatives[1].copy()
-
-     if len(derivatives) > 2:
-         for i, d_diag in enumerate(derivatives[2:], 3):
-             H_diag += d_diag * (s**(i - 2)) / math.factorial(i - 2)
-
-     if prox != 0:
-         H_diag += prox
-
-     return np.diag(H_diag)
-
- def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
-     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
-     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-     bounds = None
-     if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
-
-     # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
-     if bounds is None:
-         if len(derivatives) == 1: method = 'bfgs'
-         else: method = 'trust-exact'
-     else:
-         if len(derivatives) == 1: method = 'l-bfgs-b'
-         else: method = 'trust-constr'
-
-     x_init = x0.copy()
-     v0 = _proximal_poly_v_diag(x0, c, prox, x0, derivatives)
-     if de_iters is not None and de_iters != 0:
-         if de_iters == -1: de_iters = None # let scipy decide
-         res = scipy.optimize.differential_evolution(
-             _proximal_poly_v_diag,
-             bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
-             args=(c, prox, x0.copy(), derivatives),
-             maxiter=de_iters,
-             vectorized=True,
-         )
-         if res.fun < v0: x_init = res.x
-
-     res = scipy.optimize.minimize(
-         _proximal_poly_v_diag,
-         x_init,
-         method=method,
-         args=(c, prox, x0.copy(), derivatives),
-         jac=_proximal_poly_g_diag,
-         hess=_proximal_poly_H_diag,
-         bounds=bounds
-     )
-
-     return torch.from_numpy(res.x).to(x), res.fun
-
-
-
- class DiagonalHigherOrderNewton(Module):
-     """
-     Hvp with ones doesn't give you the diagonal unless derivatives are diagonal, but somehow it still works,
-     except it doesn't work in all cases except ones where it works.
-     """
-     def __init__(
-         self,
-         order: int = 4,
-         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-         increase: float = 1.5,
-         decrease: float = 0.75,
-         trust_init: float | None = None,
-         trust_tol: float = 1,
-         de_iters: int | None = None,
-         vectorize: bool = True,
-     ):
-         if trust_init is None:
-             if trust_method == 'bounds': trust_init = 1
-             else: trust_init = 0.1
-
-         defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         order = settings['order']
-         increase = settings['increase']
-         decrease = settings['decrease']
-         trust_tol = settings['trust_tol']
-         trust_init = settings['trust_init']
-         trust_method = settings['trust_method']
-         de_iters = settings['de_iters']
-
-         trust_value = self.global_state.get('trust_value', trust_init)
-
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         with torch.enable_grad():
-             loss = var.loss = var.loss_approx = closure(False)
-
-             g = torch.autograd.grad(loss, params, create_graph=True)
-             var.grad = list(g)
-
-             derivatives = [g]
-             T = g # current derivatives tensor diagonal
-             ones = [torch.ones_like(t) for t in g]
-
-             # get all derivatives up to order
-             for o in range(2, order + 1):
-                 T = hvp(params, T, ones, create_graph=o != order)
-                 derivatives.append(T)
-
-         x0 = torch.cat([p.ravel() for p in params])
-
-         if trust_method is None: trust_method = 'none'
-         else: trust_method = trust_method.lower()
-
-         if trust_method == 'none':
-             trust_region = None
-             prox = 0
-
-         elif trust_method == 'bounds':
-             trust_region = trust_value
-             prox = 0
-
-         elif trust_method == 'proximal':
-             trust_region = None
-             prox = 1 / trust_value
-
-         else:
-             raise ValueError(trust_method)
-
-         x_star, expected_loss = _poly_minimize(
-             trust_region=trust_region,
-             prox=prox,
-             de_iters=de_iters,
-             c=loss.item(),
-             x=x0,
-             derivatives=[torch.cat([t.ravel() for t in d]) for d in derivatives],
-         )
-
-         # trust region
-         if trust_method != 'none':
-             expected_reduction = loss - expected_loss
-
-             vec_to_tensors_(x_star, params)
-             loss_star = closure(False)
-             vec_to_tensors_(x0, params)
-             reduction = loss - loss_star
-
-             # failed step
-             if reduction <= 0:
-                 x_star = x0
-                 self.global_state['trust_value'] = trust_value * decrease
-
-             # very good step
-             elif expected_reduction / reduction <= trust_tol:
-                 self.global_state['trust_value'] = trust_value * increase
-
-         difference = vec_to_tensors(x0 - x_star, params)
-         var.update = list(difference)
-         return var
-
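The removed `_poly_eval_diag` / `_proximal_poly_*` helpers above evaluate a diagonal polynomial model of the loss around the current point $x_0$. As a sketch (with $g$ the gradient, $d_i$ the diagonal estimate of the $i$-th derivative obtained by repeatedly applying `hvp` against a ones vector, and $\rho$ the proximal weight, zero unless `trust_method='proximal'`):

$$
m(x_0 + s) = c + g^\top s + \sum_{i=2}^{k} \frac{d_i^\top s^{\odot i}}{i!} + \frac{\rho}{2}\lVert s\rVert^2,
$$

where $s^{\odot i}$ is the elementwise $i$-th power. `_poly_minimize` then minimizes this model with `scipy.optimize`, optionally inside box bounds that act as the trust region.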
torchzero/modules/experimental/soapy.py (deleted)
@@ -1,163 +0,0 @@
- from operator import itemgetter
-
- import torch
-
- from ...core import Chainable, Transform
- from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
- from ..optimizers.soap import (
-     update_soap_covariances_,
-     get_orthogonal_matrix,
-     get_orthogonal_matrix_QR,
-     project,
-     project_back,
- )
-
- class SOAPY(Transform):
-     """Adam but uses scaled gradient differences for GGᵀ. Please note that this is experimental and isn't guaranteed to work.
-
-     New args:
-         scale_by_s - whether to scale gradient differences by parameter differences
-         y_to_ema2 - whether to use gradient differences for exponential moving average too
-     """
-     def __init__(
-         self,
-         beta1: float = 0.95,
-         beta2: float = 0.95,
-         shampoo_beta: float | None = 0.95,
-         precond_freq: int = 10,
-         merge_small: bool = True,
-         max_dim: int = 2_000,
-         precondition_1d: bool = True,
-         eps: float = 1e-8,
-         decay: float | None = None,
-         alpha: float = 1,
-         bias_correction: bool = True,
-         scale_by_s: bool = True,
-         y_to_ema2: bool = False,
-     ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             bias_correction=bias_correction,
-             alpha=alpha,
-             scale_by_s=scale_by_s,
-             y_to_ema2=y_to_ema2,
-         )
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         updates = []
-         # update preconditioners
-         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-             scale_by_s = setting['scale_by_s']
-             y_to_ema2 = setting['y_to_ema2']
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             if 'g_prev' not in state:
-                 state['p_prev'] = p.clone()
-                 state['g_prev'] = t.clone()
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue
-
-             p_prev = state['p_prev']
-             g_prev = state['g_prev']
-             s = p - p_prev
-             y = t - g_prev
-             if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable
-
-             state['p_prev'].copy_(p)
-             state['g_prev'].copy_(t)
-
-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
-                 else: state["exp_avg_sq"] = torch.zeros_like(t)
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
-
-                 else:
-                     state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
-
-                 if state['GG'] is not None:
-                     update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
-                     state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             z_projected = None
-             if state['GG'] is not None:
-                 if y_to_ema2: z_projected = project(y, state['Q'])
-                 else: z_projected = project(t, state['Q'])
-
-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-             exp_avg.lerp_(t, 1-beta1)
-
-             if z_projected is None:
-                 if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
-                 else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-             else:
-                 exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)
-
-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if z_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
-
-             exp_avg_sq_projected = exp_avg_sq
-
-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
-             if z_projected is not None:
-                 update = project_back(update, state["Q"])
-
-             if setting['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
-
-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-             updates.append(update)
-             state["step"] += 1
-
-             # Update is done after the gradient step to avoid using current gradients in the projection.
-             if state['GG'] is not None:
-                 update_soap_covariances_(y, state['GG'], shampoo_beta)
-                 if state['step'] % setting['precond_freq'] == 0:
-                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-         return updates
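The key difference between the removed `SOAPY` and the regular SOAP transform is visible in `apply` above: when `scale_by_s=True`, the `GG` covariances are accumulated from gradient differences scaled by the parameter step, roughly

$$
y_t = \frac{g_t - g_{t-1}}{\max\big(\lVert x_t - x_{t-1}\rVert,\ 10^{-8}\big)},
$$

and it is this $y_t$, rather than the raw gradient, that is fed to `update_soap_covariances_` (and, with `y_to_ema2=True`, to the second-moment EMA as well).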
torchzero/modules/experimental/structured_newton.py (deleted)
@@ -1,111 +0,0 @@
- # idea https://arxiv.org/pdf/2212.09841
- import warnings
- from collections.abc import Callable
- from functools import partial
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, apply_transform
- from ...utils import TensorList, vec_to_tensors
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     hessian_mat,
-     hvp,
-     hvp_fd_central,
-     hvp_fd_forward,
-     jacobian_and_hessian_wrt,
- )
-
-
- class StructuredNewton(Module):
-     """TODO. Please note that this is experimental and isn't guaranteed to work.
-     Args:
-         structure (str, optional): structure.
-         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-         hvp_method (str):
-             how to calculate hvp_method. Defaults to "autograd".
-         inner (Chainable | None, optional): inner modules. Defaults to None.
-
-     """
-     def __init__(
-         self,
-         structure: Literal[
-             "diagonal",
-             "diagonal1",
-             "diagonal_abs",
-             "tridiagonal",
-             "circulant",
-             "toeplitz",
-             "toeplitz_like",
-             "hankel",
-             "rank1",
-             "rank2", # any rank
-         ]
-         | str = "diagonal",
-         reg: float = 1e-6,
-         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-         h: float = 1e-3,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
-         super().__init__(defaults)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         reg = settings['reg']
-         hvp_method = settings['hvp_method']
-         structure = settings['structure']
-         h = settings['h']
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         if hvp_method == 'autograd':
-             grad = var.get_grad(create_graph=True)
-             def Hvp_fn1(x):
-                 return hvp(params, grad, x, retain_graph=True)
-             Hvp_fn = Hvp_fn1
-
-         elif hvp_method == 'forward':
-             grad = var.get_grad()
-             def Hvp_fn2(x):
-                 return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
-             Hvp_fn = Hvp_fn2
-
-         elif hvp_method == 'central':
-             grad = var.get_grad()
-             def Hvp_fn3(x):
-                 return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
-             Hvp_fn = Hvp_fn3
-
-         else: raise ValueError(hvp_method)
-
-         # -------------------------------- inner step -------------------------------- #
-         update = var.get_update()
-         if 'inner' in self.children:
-             update = apply_transform(self.children['inner'], update, params=params, grads=grad, var=var)
-
-         # hessian
-         if structure.startswith('diagonal'):
-             H = Hvp_fn([torch.ones_like(p) for p in params])
-             if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
-             if structure == 'diagonal_abs': torch._foreach_abs_(H)
-             torch._foreach_add_(H, reg)
-             torch._foreach_div_(update, H)
-             var.update = update
-             return var
-
-         # hessian
-         raise NotImplementedError(structure)
-
-
-
-
-
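For the `diagonal*` structures, the removed `StructuredNewton.step` estimates the Hessian diagonal with a single Hessian-vector product against a ones vector (exact only when the Hessian really is diagonal, as the `DiagonalHigherOrderNewton` docstring above also admits) and divides the update by the regularized result. A minimal, hypothetical sketch of that idea with plain `torch.autograd`, outside the torchzero module API and corresponding roughly to the `diagonal_abs` variant:

```python
import torch

def diag_newton_direction(loss_fn, params, reg=1e-6):
    """Hypothetical standalone helper, not part of torchzero."""
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    ones = [torch.ones_like(p) for p in params]
    # Hessian-vector product with a ones vector; equals diag(H) only for a diagonal Hessian
    hvps = torch.autograd.grad(grads, params, grad_outputs=ones)
    # precondition the gradient with the regularized absolute diagonal estimate
    return [g.detach() / (h.abs() + reg) for g, h in zip(grads, hvps)]
```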
torchzero/modules/lr/__init__.py (deleted)
@@ -1,2 +0,0 @@
- from .lr import LR, StepSize, Warmup
- from .adaptive import PolyakStepSize, RandomStepSize
torchzero/modules/lr/adaptive.py (deleted)
@@ -1,93 +0,0 @@
- """Various step size strategies"""
- import random
- from typing import Any
- from operator import itemgetter
- import torch
-
- from ...core import Transform
- from ...utils import TensorList, NumberList, unpack_dicts
-
-
- class PolyakStepSize(Transform):
-     """Polyak's step-size method.
-
-     Args:
-         max (float | None, optional): maximum possible step size. Defaults to None.
-         min_obj_value (int, optional):
-             (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-         use_grad (bool, optional):
-             if True, uses dot product of update and gradient to compute the step size.
-             Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
-             Defaults to True.
-         parameterwise (bool, optional):
-             if True, calculate Polyak step-size for each parameter separately,
-             if False calculate one global step size for all parameters. Defaults to False.
-         alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
-     """
-     def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
-
-         defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
-         super().__init__(defaults, uses_grad=use_grad)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         tensors = TensorList(tensors)
-         grads = TensorList(grads)
-         alpha = NumberList(s['alpha'] for s in settings)
-
-         parameterwise, use_grad, max, min_obj_value = itemgetter('parameterwise', 'use_grad', 'max', 'min_obj_value')(settings[0])
-
-         if use_grad: denom = tensors.dot(grads)
-         else: denom = tensors.dot(tensors)
-
-         if parameterwise:
-             polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
-             polyak_step_size = polyak_step_size.where(denom != 0, 0)
-             if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
-
-         else:
-             if denom.abs() <= torch.finfo(denom.dtype).eps: polyak_step_size = 0 # converged
-             else: polyak_step_size = (loss - min_obj_value) / denom
-
-             if max is not None:
-                 if polyak_step_size > max: polyak_step_size = max
-
-         tensors.mul_(alpha * polyak_step_size)
-         return tensors
-
-
- class RandomStepSize(Transform):
-     """Uses random global or layer-wise step size from `low` to `high`.
-
-     Args:
-         low (float, optional): minimum learning rate. Defaults to 0.
-         high (float, optional): maximum learning rate. Defaults to 1.
-         parameterwise (bool, optional):
-             if True, generate random step size for each parameter separately,
-             if False generate one global random step size. Defaults to False.
-     """
-     def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
-         defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         s = settings[0]
-         parameterwise = s['parameterwise']
-
-         seed = s['seed']
-         if 'generator' not in self.global_state:
-             self.global_state['generator'] = random.Random(seed)
-         generator: random.Random = self.global_state['generator']
-
-         if parameterwise:
-             low, high = unpack_dicts(settings, 'low', 'high')
-             lr = [generator.uniform(l, h) for l, h in zip(low, high)]
-         else:
-             low = s['low']
-             high = s['high']
-             lr = generator.uniform(low, high)
-
-         torch._foreach_mul_(tensors, lr)
-         return tensors
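For reference, the step size that the removed `PolyakStepSize.apply` computes in the global, `use_grad=True` case is the classic Polyak rule, clamped to `max` when given:

$$
\eta_t = \alpha\,\frac{f(x_t) - f^{*}}{\langle u_t,\ g_t\rangle},
$$

where $u_t$ is the incoming update, $g_t$ the gradient and $f^{*}$ the `min_obj_value` estimate of the lowest attainable loss.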
torchzero/modules/lr/lr.py (deleted)
@@ -1,63 +0,0 @@
- """Learning rate"""
- import torch
-
- from ...core import Transform
- from ...utils import NumberList, TensorList, generic_eq, unpack_dicts
-
- def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
-     """multiplies by lr if lr is not 1"""
-     if generic_eq(lr, 1): return tensors
-     if inplace: return tensors.mul_(lr)
-     return tensors * lr
-
- class LR(Transform):
-     """Learning rate. Adding this module also adds support for LR schedulers."""
-     def __init__(self, lr: float):
-         defaults=dict(lr=lr)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
-
- class StepSize(Transform):
-     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
-     def __init__(self, step_size: float, key = 'step_size'):
-         defaults={"key": key, key: step_size}
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
-
-
- def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
-     """returns warm up lr scalar"""
-     if step > steps: return end_lr
-     return start_lr + (end_lr - start_lr) * (step / steps)
-
- class Warmup(Transform):
-     """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
-
-     Args:
-         start_lr (_type_, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
-         end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
-         steps (int, optional): number of steps to perform warmup for. Defaults to 100.
-     """
-     def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
-         defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
-         num_steps = settings[0]['steps']
-         step = self.global_state.get('step', 0)
-
-         target = lazy_lr(
-             TensorList(tensors),
-             lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
-             inplace=True
-         )
-         self.global_state['step'] = step + 1
-         return target
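Finally, the removed `_warmup_lr` helper is a plain linear ramp of the learning-rate multiplier,

$$
\mathrm{lr}(t) =
\begin{cases}
\mathrm{start\_lr} + (\mathrm{end\_lr} - \mathrm{start\_lr})\,\dfrac{t}{\mathrm{steps}}, & t \le \mathrm{steps},\\
\mathrm{end\_lr}, & t > \mathrm{steps},
\end{cases}
$$

which `Warmup.apply` applies through `lazy_lr` while counting steps in `global_state`. Judging by the file list above, this functionality appears to have moved to `torchzero/modules/step_size/lr.py` in 0.3.11.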