torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
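Entries 16-25 show a new top-level `torchzero.linalg` package, and entries 179-182 and 185 show the old `torchzero.utils.linalg` contents being removed or relocated into it. Below is a minimal, hedged sketch of what that migration looks like for downstream imports; the `matrix_power(tensor, power, method=...)` call signature and the `"eigh_abs"` default are taken from the `shampoo.py` hunks later on this page, while the test matrix and chosen exponent are illustrative only.

```python
# Sketch of the 0.3.15 -> 0.4.1 import migration for the linalg helpers.
# Old location (removed in 0.4.1, see entries 179-182):
#     from torchzero.utils.linalg import matrix_power_eigh
# New location (entries 20 and 23, used by shampoo.py below):
from torchzero.linalg.matrix_power import MatrixPowerMethod, matrix_power
import torch

# Illustrative use mirroring what the new Shampoo code does: raise a symmetric
# positive-definite accumulator to a negative fractional power.
A = torch.randn(8, 8)
G = A @ A.T + 1e-3 * torch.eye(8)        # SPD "GG^T"-style accumulator
method: MatrixPowerMethod = "eigh_abs"   # the new Shampoo default
P = matrix_power(G, -0.25, method=method)
```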
torchzero/modules/adaptive/shampoo.py

@@ -1,11 +1,10 @@
- from collections.abc import Sequence
- from operator import itemgetter
- from functools import partial
+ from collections.abc import Sequence, Iterable
+
  import numpy as np
  import torch

- from ...core import Chainable, Transform, apply_transform
- from ...utils.linalg import matrix_power_eigh
+ from ...core import Chainable, TensorTransform
+ from ...linalg.matrix_power import MatrixPowerMethod, matrix_power as _matrix_power
  from ...utils import set_storage_


@@ -14,10 +13,11 @@ def update_shampoo_preconditioner_(
      accumulators_: list[torch.Tensor | None],
      preconditioners_: list[torch.Tensor | None],
      step: int,
-     update_freq: int,
-     exp_override: int | None,
+     precond_freq: int,
+     matrix_power: float | None,
      beta: float | None,
-     reg: float
+     reg: float,
+     matrix_power_method: MatrixPowerMethod,
  ):
      for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
          if accumulator is None: continue
@@ -27,22 +27,20 @@ def update_shampoo_preconditioner_(
          if beta is None: accumulator.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
          else: accumulator.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

-         if step % update_freq == 0:
-             matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
+         if step % precond_freq == 0:
              if reg != 0:
                  accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
-             set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))

+             if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
+             set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))


  def apply_shampoo_preconditioner(
      tensor: torch.Tensor,
      preconditioners_: list[torch.Tensor | None],
-     decay: float | None,
  ):
      for i, preconditioner in enumerate(preconditioners_):
          if preconditioner is None: continue
          tensor = torch.tensordot(tensor, preconditioner, ([0], [0])) # pyright:ignore[reportArgumentType]
-         if decay is not None: preconditioner.mul_(decay)
      return tensor


@@ -50,9 +48,8 @@ def update_diagonal_(grad: torch.Tensor, diagonal_accumulator_: torch.Tensor, be
      if beta is None: diagonal_accumulator_.add_(grad.pow(2))
      else: diagonal_accumulator_.mul_(beta).addcmul_(grad, grad, value=1-beta)

- def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, decay: float | None, eps: float):
+ def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, eps: float):
      grad_.div_(diagonal_accumulator_.sqrt() + eps)
-     if decay is not None: diagonal_accumulator_.mul_(decay)
      return grad_

  def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
@@ -85,145 +82,167 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
      tensor = tensor.unflatten(0, flat_sizes)
      return tensor.permute(*np.argsort(sort_idxs).tolist())

+ def diagonal_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor]):
+     """computes number of parameters"""
+     if isinstance(params, torch.nn.Module): params = params.parameters()
+     if isinstance(params, torch.Tensor): params = [params,]
+     params = list(params)
+     return sum(p.numel() for p in params)
+
+ def kronecker_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor], merge_small:bool=True, max_dim:int=10_000):
+     """computes total size of tensors required to store shampoo preconditioner"""
+     if isinstance(params, torch.nn.Module): params = params.parameters()
+     if isinstance(params, torch.Tensor): params = [params,]
+     params = list(params)
+
+     memory = 0
+     for p in params:
+         if merge_small:
+             p, _, _ = _merge_small_dims(p, max_dim)
+         for dim in p.size():
+             if dim > max_dim: memory += dim
+             else: memory += dim**2
+
+     return memory
+

- class Shampoo(Transform):
+
+
+ class Shampoo(TensorTransform):
      """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

-     .. note::
+     Notes:
          Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

-     .. note::
-         Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+         Shampoo is a very computationally expensive optimizer, increase ``update_freq`` if it is too slow.

-     .. note::
-         SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+         SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as ``tz.m.SOAP``.

      Args:
-         decay (float | None, optional): slowly decays preconditioners. Defaults to None.
-         beta (float | None, optional):
-             if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
          update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-         exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
+         matrix_power (float | None, optional): overrides matrix exponent. By default uses ``-1/grad.ndim``. Defaults to None.
          merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
-         max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
+         max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 10_000.
          precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
          adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
+         matrix_power_method (MatrixPowerMethod, optional): how to compute matrix power.
+         beta (float | None, optional):
+             if None calculates sum as in standard Shampoo, otherwise uses EMA of preconditioners. Defaults to None.
          inner (Chainable | None, optional):
              module applied after updating preconditioners and before applying preconditioning.
              For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
              Defaults to None.

      Examples:
-         Shampoo grafted to Adam
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.GraftModules(
-                     direction = tz.m.Shampoo(),
-                     magnitude = tz.m.Adam(),
-                 ),
-                 tz.m.LR(1e-3)
-             )
-
-         Adam with Shampoo preconditioner
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
-                 tz.m.Debias(0.9, 0.999),
-                 tz.m.LR(1e-3)
-             )
+         Shampoo grafted to Adam
+
+         ```python
+         opt = tz.Optimizer(
+             model.parameters(),
+             tz.m.GraftModules(
+                 direction = tz.m.Shampoo(),
+                 magnitude = tz.m.Adam(),
+             ),
+             tz.m.LR(1e-3)
+         )
+         ```
+
+         Adam with Shampoo preconditioner
+
+         ```python
+         opt = tz.Optimizer(
+             model.parameters(),
+             tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+             tz.m.Debias(0.9, 0.999),
+             tz.m.LR(1e-3)
+         )
+         ```
      """
      def __init__(
          self,
-         decay: float | None = None,
-         beta: float | None = None,
          reg: float = 1e-12,
-         update_freq: int = 10,
-         exp_override: int | None = 2,
+         precond_freq: int = 10,
+         matrix_power: float | None = None,
          merge_small: bool = True,
-         max_dim: int = 2_000,
+         max_dim: int = 10_000,
          precondition_1d: bool = True,
          adagrad_eps: float = 1e-8,
+         matrix_power_method: MatrixPowerMethod = "eigh_abs",
+         beta: float | None = None,
+         beta_debias: bool = True,
+
          inner: Chainable | None = None,
      ):
-         defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps, reg=reg)
-         super().__init__(defaults, uses_grad=False)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         merged_tensors = [] # target with merged dims
-
-         # update preconditioners
-         for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
-             beta, update_freq, exp_override, merge_small, max_dim, precondition_1d, reg = itemgetter(
-                 'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d', "reg")(setting)
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             merged_tensors.append(t)
-
-             # initialize accumulators and preconditioners for each dim on 1st step
-             if 'accumulators' not in state:
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['accumulators'] = []
-
-                 else:
-                     state['accumulators'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-                     state['preconditioners'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or too big, then basic diagonal preconditioner is used.
-                 if len([i is not None for i in state['accumulators']]) == 0:
-                     state['diagonal_accumulator'] = torch.zeros_like(t)
-
-                 state['step'] = 0
-
-             # update preconditioners
-             if 'diagonal_accumulator' in state:
-                 update_diagonal_(t, state['diagonal_accumulator'], beta)
-             else:
-                 update_shampoo_preconditioner_(
-                     t,
-                     accumulators_=state['accumulators'],
-                     preconditioners_=state['preconditioners'],
-                     step=state['step'],
-                     update_freq=update_freq,
-                     exp_override=exp_override,
-                     beta=beta,
-                     reg=reg,
-                 )
-
-         # inner step
-         if 'inner' in self.children:
-             tensors = apply_transform(self.children['inner'], tensors, params=params, grads=grads)
-
-             # have to merge small dims again
-             merged_tensors = [] # target with merged dims
-             for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
-                 if setting['merge_small']:
-                     t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, setting['max_dim'])
-                 merged_tensors.append(t)
-
-         # precondition
-         for i,(t,state, setting) in enumerate(zip(merged_tensors, states, settings)):
-             decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(setting)
-
-             if 'diagonal_accumulator' in state:
-                 tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)
-             else:
-                 tensors[i] = apply_shampoo_preconditioner(t, preconditioners_=state['preconditioners'], decay=decay)
-
-             if merge_small:
-                 tensors[i] = _unmerge_small_dims(tensors[i], state['flat_sizes'], state['sort_idxs'])
-
-             state['step'] += 1
-
-         return tensors
+         defaults = locals().copy()
+         del defaults['self'], defaults["inner"]
+
+         super().__init__(defaults, inner=inner)
+
+     @torch.no_grad
+     def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+         if setting["merge_small"]:
+             tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+         if tensor.ndim <= 1 and not setting["precondition_1d"]:
+             state["accumulators"] = []
+
+         else:
+             max_dim = setting["max_dim"]
+             state['accumulators'] = [
+                 torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+             ]
+             state['preconditioners'] = [
+                 torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+             ]
+
+         # either scalar parameter, 1d with precondition_1d=False, or too big, then diagonal preconditioner is used.
+         if len([i is not None for i in state['accumulators']]) == 0:
+             state['diagonal_accumulator'] = torch.zeros_like(tensor)
+
+         state['step'] = 0
+         state["num_GTG"] = 0
+
+     @torch.no_grad
+     def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+         if setting["merge_small"]:
+             tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+         if 'diagonal_accumulator' in state:
+             update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
+         else:
+             update_shampoo_preconditioner_(
+                 tensor,
+                 accumulators_=state['accumulators'],
+                 preconditioners_=state['preconditioners'],
+                 step=state['step'],
+                 precond_freq=setting["precond_freq"],
+                 matrix_power=setting["matrix_power"],
+                 beta=setting["beta"],
+                 reg=setting["reg"],
+                 matrix_power_method=setting["matrix_power_method"],
+             )
+
+         if state["step"] % setting["precond_freq"] == 0:
+             state["num_GTG"] += 1
+
+         state["step"] += 1
+
+     @torch.no_grad
+     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+         if setting["merge_small"]:
+             tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+         if 'diagonal_accumulator' in state:
+             dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
+         else:
+             dir = apply_shampoo_preconditioner(tensor, preconditioners_=state['preconditioners'])
+
+         if setting["merge_small"]:
+             dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])
+
+         if setting['beta_debias'] and setting["beta"] is not None:
+             bias_correction = 1 - (setting["beta"] ** state["num_GTG"])
+             dir *= bias_correction ** 0.5
+
+         return dir
+
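For orientation, here is a hedged sketch of how the rewritten module might be constructed after this release. The `tz.Optimizer`, `tz.m.GraftModules`, `tz.m.Adam`, and `tz.m.LR` names and the grafting layout are copied from the new docstring examples above, and the keyword arguments come from the new `__init__` signature; how the optimizer is stepped is not shown in this diff, so the training-loop lines are an assumption modeled on the usual `torch.optim` protocol.

```python
# Sketch based on the new docstring examples and __init__ signature above.
# The stepping protocol (backward + opt.step()) is an assumption, not shown in this diff.
import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 2)

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GraftModules(
        # renamed/new knobs in 0.4.1: precond_freq (was update_freq),
        # matrix_power (was exp_override), matrix_power_method, beta, beta_debias
        direction=tz.m.Shampoo(precond_freq=10, matrix_power_method="eigh_abs", beta=0.999),
        magnitude=tz.m.Adam(),
    ),
    tz.m.LR(1e-3),
)

loss = model(torch.randn(4, 10)).square().mean()
loss.backward()
opt.step()  # assumed to follow the standard torch.optim step() convention
```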