torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/soap.py +207 -143

@@ -3,8 +3,10 @@ import warnings

  import torch

- from ...core import Chainable, Transform, apply_transform
+ from ...core import TensorTransform, Chainable
+ from ...utils import unpack_dicts, unpack_states, TensorList, NumberList
  from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
+ from ...linalg import torch_linalg

  @torch.no_grad
  def update_soap_covariances_(
@@ -20,52 +22,48 @@ def update_soap_covariances_(
          else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

  @torch.no_grad
- def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+ def project(tensor: torch.Tensor, Q: list[torch.Tensor | None]):
      """
      Projects the gradient to the eigenbases of the preconditioner.
      """
-     for mat in Q:
-         if mat is not None and len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
+     for M in Q:
+         if M is not None:
+             tensor = torch.tensordot(tensor, M, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
          else:
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
+             permute_order = list(range(1, len(tensor.shape))) + [0]
+             tensor = tensor.permute(permute_order)

-     return tensors
+     return tensor

  @torch.no_grad
- def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
+ def project_back(tensor: torch.Tensor, Q: list[torch.Tensor| None]):
      """
      Projects the gradient back to the original space.
      """
-     for mat in Q:
-         if mat is not None and len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
+     for M in Q:
+         if M is not None:
+             tensor = torch.tensordot(tensor, M, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
          else:
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
+             permute_order = list(range(1, len(tensor.shape))) + [0]
+             tensor = tensor.permute(permute_order)

-     return tensors
+     return tensor

  # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
  @torch.no_grad
- def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
+ def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
      """
      Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
      """

      final = []
-     for m in mat:
+     for M in mats:

-         if m is None or len(m) == 0:
-             final.append([])
+         if M is None:
+             final.append(None)
              continue

-         try:
-             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-         except torch.linalg.LinAlgError:
-             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-             Q = Q.to(m.dtype)
+         _, Q = torch_linalg.eigh(M + 1e-30 * torch.eye(M.shape[0], device=M.device), retry_float64=True)

          Q = torch.flip(Q, [1])
          final.append(Q)
@@ -78,30 +76,33 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
      """
      Computes the eigenbases of the preconditioner using one round of power iteration
      followed by torch.linalg.qr decomposition.
-     """
+
+     Approximately modifies ``exp_avg_sq`` to be in the new eigenbases.
+     """
      final = []

-     for ind, (m,o) in enumerate(zip(GG, Q_list)):
+     for ind, (M, O) in enumerate(zip(GG, Q_list)):

          # skip 1d or large dims
-         if m is None or len(m) == 0:
-             final.append([])
+         if M is None:
+             final.append(None)
              continue
-         assert o is not None

-         est_eig = torch.diag(o.T @ m @ o)
+         assert O is not None
+
+         est_eig = torch.diagonal(O.T @ M @ O)
          sort_idx = torch.argsort(est_eig, descending=True)
          exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)

-         power_iter = m @ o[:, sort_idx]
-         Q, _ = torch.linalg.qr(power_iter.to(torch.float32)) # pylint:disable=not-callable
+         power_iter = M @ O[:, sort_idx]
+         Q, _ = torch_linalg.qr(power_iter.to(torch.float32), retry_float64=True)
          Q = Q.to(power_iter.dtype)

          final.append(Q)

      return final, exp_avg_sq

- class SOAP(Transform):
+ class SOAP(TensorTransform):
      """SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

      Args:
@@ -111,35 +112,42 @@ class SOAP(Transform):
              beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
          precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
          merge_small (bool, optional): Whether to merge small dims. Defaults to True.
-         max_dim (int, optional): Won't precondition dims larger than this. Defaults to 2_000.
+         max_dim (int, optional): Won't precondition dims larger than this. Defaults to 10_000.
          precondition_1d (bool, optional):
              Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
          eps (float, optional):
              epsilon for dividing first momentum by second. Defaults to 1e-8.
-         decay (float | None, optional):
-             Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
+         debias (bool, optional):
+             enables adam bias correction. Defaults to True.
+         proj_exp_avg (bool, optional):
+             if True, maintains exponential average of gradients (momentum) in projected space.
+             If False - in original space Defaults to True.
          alpha (float, optional):
              learning rate. Defaults to 1.
-         bias_correction (bool, optional):
-             enables adam bias correction. Defaults to True.
-
-     Examples:
-         SOAP:
-
-         .. code-block:: python
-
-             opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))
-
-         Stabilized SOAP:
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.SOAP(),
-                 tz.m.NormalizeByEMA(max_ema_growth=1.2),
-                 tz.m.LR(1e-2)
-             )
+         inner (Chainable | None, optional):
+             output of this module is projected and Adam will run on it, but preconditioners are updated
+             from original gradients.
+
+     ### Examples:
+     SOAP:
+
+     ```python
+     opt = tz.Optimizer(
+         model.parameters(),
+         tz.m.SOAP(),
+         tz.m.LR(1e-3)
+     )
+     ```
+     Stabilized SOAP:
+
+     ```python
+     opt = tz.Optimizer(
+         model.parameters(),
+         tz.m.SOAP(),
+         tz.m.NormalizeByEMA(max_ema_growth=1.2),
+         tz.m.LR(1e-2)
+     )
+     ```
      """
      def __init__(
          self,
@@ -148,118 +156,174 @@ class SOAP(Transform):
          shampoo_beta: float | None = 0.95,
          precond_freq: int = 10,
          merge_small: bool = True,
-         max_dim: int = 2_000,
+         max_dim: int = 4096,
          precondition_1d: bool = True,
          eps: float = 1e-8,
-         decay: float | None = None,
+         debias: bool = True,
+         proj_exp_avg: bool = True,
          alpha: float = 1,
-         bias_correction: bool = True,
+
+         inner: Chainable | None = None,
      ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             bias_correction=bias_correction,
-             alpha=alpha,
-         )
-         super().__init__(defaults, uses_grad=False)
+         defaults = locals().copy()
+         del defaults['self'], defaults["inner"]
+
+         super().__init__(defaults)
+         self.set_child("inner", inner)

      @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         updates = []
-         # update preconditioners
-         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps,alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps','alpha')(setting)
+     def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+         if setting["merge_small"]:
+             tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+         state["exp_avg_proj"] = torch.zeros_like(tensor)
+         state["exp_avg_sq_proj"] = torch.zeros_like(tensor)

-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
+         if tensor.ndim <= 1 and not setting["precondition_1d"]:
+             state['GG'] = []

-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 state["exp_avg_sq_projected"] = torch.zeros_like(t)
+         else:
+             max_dim = setting["max_dim"]
+             state['GG'] = [
+                 torch.zeros(s, s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+             ]

-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
+         # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+         if len([i is not None for i in state['GG']]) == 0:
+             state['GG'] = None

-                 else:
-                     state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
+         # first covariance accumulation
+         if state['GG'] is not None:
+             update_soap_covariances_(tensor, GGs_=state['GG'], beta=setting["shampoo_beta"])

-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
+             # get projection matrix with first gradients with eigh
+             try: state['Q'] = get_orthogonal_matrix(state['GG'])
+             except torch.linalg.LinAlgError as e:
+                 warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
+                 state["GG"] = None
+
+         state['step'] = 0
+
+
+     # no update to avoid running merge_dims twice
+
+     @torch.no_grad
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+         # note
+         # do not modify tensors in-place
+         # because they are used to update preconditioner at the end
+
+         steps = [s["step"] for s in states]
+         if any(s == 0 for s in steps):
+             # skip 1st update so to avoid using current gradient in the projection
+             # I scale it instead to avoid issues with further modules
+             for s in states: s["step"] += 1
+             return TensorList(tensors).clamp(-0.1, 0.1)
+             # return TensorList(tensors).zero_()
+
+
+         fs = settings[0]
+         merged = []
+         projected = []
+         # ---------------------------------- project --------------------------------- #
+
+         for tensor, state, setting in zip(tensors, states, settings):
+             if setting["merge_small"]:
+                 tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+             merged.append(tensor)

-             if state['GG'] is not None:
-                 update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
-                 try: state['Q'] = get_orthogonal_matrix(state['GG'])
-                 except torch.linalg.LinAlgError as e:
-                     warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
-                     state["GG"] = None
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1, 0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # I use scaled update instead as to not mess up with next modules.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             t_projected = None
              if state['GG'] is not None:
-                 t_projected = project(t, state['Q'])
+                 tensor = project(tensor, state['Q'])

-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]
+             projected.append(tensor)

-             exp_avg.lerp_(t, 1-beta1)
+         # ------------------------ run adam in projected space ----------------------- #
+         exp_avg_proj, exp_avg_sq_proj = unpack_states(states, tensors, "exp_avg_proj", "exp_avg_sq_proj", must_exist=True, cls=TensorList)
+         alpha, beta1, beta2, eps = unpack_dicts(settings, "alpha", "beta1", "beta2", "eps", cls=NumberList)

-             if t_projected is None:
-                 exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
-             else:
-                 exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+         # lerp exp_avg in projected space
+         if fs["proj_exp_avg"]:
+             exp_avg_proj.lerp_(projected, weight=1-beta1)

-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if t_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
+         # or lerp in original space and project
+         else:
+             exp_avg = exp_avg_proj
+             exp_avg.lerp_(merged, weight=1-beta1)
+             exp_avg_proj = []
+             for t, state, setting in zip(exp_avg, states, settings):
+                 if state['GG'] is not None:
+                     t = project(t, state["Q"])
+                 exp_avg_proj.append(t)

-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+         exp_avg_sq_proj.mul_(beta2).addcmul_(projected, projected, value=1-beta2)

-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
+         denom = exp_avg_sq_proj.sqrt().add_(eps)
+         dirs_proj = exp_avg_proj / denom

-             if t_projected is not None:
-                 update = project_back(update, state["Q"])
+         # ------------------------------- project back ------------------------------- #
+         dirs: list[torch.Tensor] = []
+         for dir, state, setting in zip(dirs_proj, states, settings):
+             if state['GG'] is not None:
+                 dir = project_back(dir, state['Q'])

-             if setting['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
+             if setting["merge_small"]:
+                 dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])

-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+             dirs.append(dir)

-             updates.append(update)
-             state["step"] += 1

-         # Update is done after the gradient step to avoid using current gradients in the projection.
+         # -------------------------------- inner step -------------------------------- #
+         if "inner" in self.children:
+             tensors = self.inner_step_tensors("inner", tensors, clone=False,
+                                               params=params, grads=grads,loss=loss)
+
+             # we now have to re-merge small dims on updated tensors
+             merged = []
+             for tensor, state, setting in zip(tensors, states, settings):
+                 if setting["merge_small"]:
+                     tensor, _, _ = _merge_small_dims(tensor, setting["max_dim"])
+                 merged.append(tensor)
+
+         # -------------------------- update preconditioners -------------------------- #
+         # Update is done after the gradient step to avoid using current gradients in the projection.
+
+         for tensor, state, setting in zip(merged, states, settings):
              if state['GG'] is not None:
-                 update_soap_covariances_(t, state['GG'], shampoo_beta)
-                 if state['step'] % setting['precond_freq'] == 0:
+
+                 # lerp covariances
+                 update_soap_covariances_(tensor, state['GG'], beta=setting["shampoo_beta"])
+
+                 # (state['step'] - 1) since we start updating on 2nd step
+                 if (state['step'] - 1) % setting['precond_freq'] == 0:
+
+                     # unproject exp_avg before updating if it is maintained projected
+                     exp_avg = None
+                     if fs["proj_exp_avg"]:
+                         exp_avg = project_back(state["exp_avg_proj"], state["Q"])
+
+                     # update projection matrix and exp_avg_sq_proj
                      try:
-                         state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
+                         state['Q'], state['exp_avg_sq_proj'] = get_orthogonal_matrix_QR(
+                             state["exp_avg_sq_proj"], state['GG'], state['Q'])
+
+                         # re-project exp_avg if it is maintained projected
+                         if fs["proj_exp_avg"]:
+                             assert exp_avg is not None
+                             state["exp_avg_proj"] = project(exp_avg, state["Q"])
+
                      except torch.linalg.LinAlgError:
                          pass
-         return updates
+
+             state["step"] += 1
+
+
+         # ------------------------- bias-corrected step size ------------------------- #
+         if fs["debias"]:
+             steps1 = [s+1 for s in steps]
+             bias_correction1 = 1.0 - beta1 ** steps1
+             bias_correction2 = 1.0 - beta2 ** steps1
+             alpha = alpha * (bias_correction2 ** .5) / bias_correction1
+
+         torch._foreach_mul_(dirs, alpha)
+         return dirs
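
For orientation, here is a minimal, self-contained sketch (not part of the package) of the eigenbasis rotation that the new `project` / `project_back` helpers in the diff perform. It uses plain `torch.linalg.eigh` on toy shapes in place of the package's `torch_linalg.eigh(..., retry_float64=True)` wrapper and the accumulated `GG` covariances:

```python
# Illustrative sketch only; mirrors the tensordot-based rotation shown in the diff above.
import torch

g = torch.randn(8, 5)                                            # toy "gradient"
GG = [g @ g.T, g.T @ g]                                          # per-dimension covariance accumulators
Q = [torch.flip(torch.linalg.eigh(M).eigenvectors, [1]) for M in GG]  # eigenbases, largest eigenvalue first

# project: rotate each mode of the gradient into its eigenbasis (for a matrix: Q[0].T @ g @ Q[1])
proj = g
for M in Q:
    proj = torch.tensordot(proj, M, dims=[[0], [0]])

# project_back: undo the rotation; each Q is orthogonal, so this recovers the original tensor
back = proj
for M in Q:
    back = torch.tensordot(back, M, dims=[[0], [1]])

assert torch.allclose(back, g, atol=1e-4)                        # round-trips up to float error
```

In the module itself these eigenbases are initialized by `get_orthogonal_matrix` from the first accumulated covariances and refreshed every `precond_freq` steps by `get_orthogonal_matrix_QR`, while Adam's `exp_avg_proj` / `exp_avg_sq_proj` statistics are maintained in the projected space (controlled by `proj_exp_avg`).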