torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff shows the contents of the publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two package versions.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/newtonnewton.py (new file)
@@ -0,0 +1,88 @@
+import itertools
+import warnings
+from collections.abc import Callable
+from contextlib import nullcontext
+from functools import partial
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    jacobian_wrt,
+)
+from ..second_order.newton import (
+    cholesky_solve,
+    eigh_solve,
+    least_squares_solve,
+    lu_solve,
+)
+
+
+class NewtonNewton(Module):
+    """
+    Method that I thought of and then it worked.
+
+    1. Calculate newton step by solving Hx=g
+
+    2. Calculate jacobian of x wrt parameters and call it H2
+
+    3. Solve H2 x2 = x for x2.
+
+    4. Optionally, repeat (if order is higher than 3.)
+
+    Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unstable and have little benefit.
+    """
+    def __init__(
+        self,
+        reg: float = 1e-6,
+        order: int = 3,
+        search_negative: bool = False,
+        vectorize: bool = True,
+        eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+    ):
+        defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_tfm=eigval_tfm, search_negative=search_negative)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        reg = settings['reg']
+        vectorize = settings['vectorize']
+        order = settings['order']
+        search_negative = settings['search_negative']
+        eigval_tfm = settings['eigval_tfm']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        with torch.enable_grad():
+            loss = var.loss = var.loss_approx = closure(False)
+            g_list = torch.autograd.grad(loss, params, create_graph=True)
+            var.grad = list(g_list)
+
+            xp = torch.cat([t.ravel() for t in g_list])
+            I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
+
+            for o in range(2, order + 1):
+                is_last = o == order
+                H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
+                with torch.no_grad() if is_last else nullcontext():
+                    H = hessian_list_to_mat(H_list)
+                    if reg != 0: H = H + I * reg
+
+                    x = None
+                    if search_negative or (is_last and eigval_tfm is not None):
+                        x = eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                    if x is None: x = cholesky_solve(H, xp)
+                    if x is None: x = lu_solve(H, xp)
+                    if x is None: x = least_squares_solve(H, xp)
+                    xp = x.squeeze()
+
+        var.update = vec_to_tensors(xp, params)
+        return var
+
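For readers skimming the new module above: the loop is simply "solve, differentiate the solution, solve again". The standalone sketch below reproduces that idea with plain torch on a toy function, without any torchzero machinery. The function and helper names are ours, not the package's, and the dense Jacobian keeps it strictly toy-scale.

import torch

def newton_newton_direction(f, x, order=3, reg=1e-6):
    # v starts as the gradient; each pass builds the Jacobian of v w.r.t. x
    # (the Hessian on the first pass), solves J d = v, then repeats on d.
    x = x.detach().clone().requires_grad_(True)
    v = torch.autograd.grad(f(x), x, create_graph=True)[0]
    I = torch.eye(x.numel(), dtype=x.dtype, device=x.device)
    for o in range(2, order + 1):
        is_last = o == order
        J = torch.stack([
            torch.autograd.grad(v[i], x, retain_graph=True, create_graph=not is_last)[0]
            for i in range(v.numel())
        ])
        v = torch.linalg.solve(J + reg * I, v)  # order 2 stops here: the damped Newton step
    return v.detach()

# on a convex quadratic the order-2 (plain Newton) direction lands on the minimum
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
f = lambda x: 0.5 * x @ A @ x
x0 = torch.tensor([1.0, -2.0])
print(x0 - newton_newton_direction(f, x0, order=2))  # ≈ tensor([0., 0.])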
torchzero/modules/experimental/reduce_outward_lr.py
@@ -1,30 +1,33 @@
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList
+from ...utils import TensorList, unpack_states, unpack_dicts
 
 class ReduceOutwardLR(Transform):
     """
     When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
+
+    A note on this is that it sounded good but it's really bad in practice.
     """
     def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
         super().__init__(defaults, uses_grad=use_grad, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)
 
-        mul = self.get_settings('mul', params=params)
-        s = self.settings[params[0]]
+        mul = [s['mul'] for s in settings]
+        s = settings[0]
         use_grad = s['use_grad']
         invert = s['invert']
 
-        if use_grad: cur = vars.get_grad()
+        if use_grad: cur = grads
         else: cur = tensors
+        assert cur is not None
 
         # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
         if invert: mask = (params * cur) > 0
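The signature change in this hunk, transform(self, tensors, params, grads, vars) → apply(self, tensors, params, grads, loss, states, settings), is the pattern repeated across most transforms in this release: per-parameter state and settings are now passed in as lists rather than looked up through self.settings or vars. Below is a minimal sketch of a custom transform written against that shape, inferred only from the hunks in this diff; the class, its behaviour, and the assumption that Transform is importable from torchzero.core (as the relative imports suggest) are our own illustration, and the exact return contract is not documented here.

import torch
from torchzero.core import Transform  # assumed public import path

class MulByConstant(Transform):
    """Toy transform using the 0.3.10-style `apply` signature seen above (illustrative only)."""
    def __init__(self, factor: float = 0.1):
        defaults = dict(factor=factor)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        # `settings` (and `states`) line up one-to-one with `tensors`/`params`,
        # so per-parameter options are read from the matching dict, as SOAPY does below.
        return [t * s['factor'] for t, s in zip(tensors, settings)]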
torchzero/modules/experimental/soapy.py
@@ -2,147 +2,22 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Transform, apply
+from ...core import Chainable, Transform
 from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-
-@torch.no_grad
-def update_soap_covariances_(
-    grad: torch.Tensor,
-    GGs_: list[torch.Tensor | None],
-    beta: float | None,
-):
-    for i, GG in enumerate(GGs_):
-        if GG is None: continue
-
-        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-@torch.no_grad
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq
+from ..optimizers.soap import (
+    update_soap_covariances_,
+    get_orthogonal_matrix,
+    get_orthogonal_matrix_QR,
+    project,
+    project_back,
+)
 
 class SOAPY(Transform):
-    """SOAP but uses scaled gradient differences
-
-    new args
-
-    scale by s whether to scale gradient differences by parameter differences
+    """Adam but uses scaled gradient differences for GGᵀ. Please note that this is experimental and isn't guaranteed to work.
 
-    y_to_ema2 whether to use gradient differences for exponential moving average too
+    New args:
+        scale_by_s - whether to scale gradient differences by parameter differences
+        y_to_ema2 - whether to use gradient differences for exponential moving average too
     """
     def __init__(
         self,
@@ -178,16 +53,14 @@ class SOAPY(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
-            scale_by_s = settings['scale_by_s']
-            y_to_ema2 = settings['y_to_ema2']
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
+            scale_by_s = setting['scale_by_s']
+            y_to_ema2 = setting['y_to_ema2']
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -268,7 +141,7 @@ class SOAPY(Transform):
             if z_projected is not None:
                 update = project_back(update, state["Q"])
 
-            if settings['bias_correction']:
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -284,7 +157,7 @@ class SOAPY(Transform):
             # Update is done after the gradient step to avoid using current gradients in the projection.
             if state['GG'] is not None:
                 update_soap_covariances_(y, state['GG'], shampoo_beta)
-                if state['step'] % settings['precond_freq'] == 0:
+                if state['step'] % setting['precond_freq'] == 0:
                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
 
         return updates
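The roughly 130 helper lines deleted from soapy.py above are not gone: the new import block pulls update_soap_covariances_, get_orthogonal_matrix, get_orthogonal_matrix_QR, project and project_back from torchzero/modules/optimizers/soap.py instead. The round-trip property the projection helpers rely on is easy to sanity-check in isolation. The snippet below re-implements the two tensordot loops from the deleted block as stripped-down copies of our own (ignoring the None/empty-matrix branches) and checks that projecting into per-mode orthonormal bases and back recovers the original tensor; it is a sanity check, not the package's test.

import torch

def project(t, Q):
    # contract axis 0 with each basis' rows; the rotated axis cycles to the back
    for mat in Q:
        t = torch.tensordot(t, mat, dims=([0], [0]))
    return t

def project_back(t, Q):
    # contract axis 0 with each basis' columns, undoing the rotation above
    for mat in Q:
        t = torch.tensordot(t, mat, dims=([0], [1]))
    return t

g = torch.randn(4, 3)
Q = [torch.linalg.qr(torch.randn(n, n)).Q for n in g.shape]  # one orthonormal basis per mode
roundtrip = project_back(project(g, Q), Q)
assert torch.allclose(roundtrip, g, atol=1e-5)
print((roundtrip - g).abs().max())  # ~1e-7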