torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
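Beyond the per-file counts, the list above records two structural moves: the `torchzero/utils/linalg` helpers now live in a top-level `torchzero/linalg` package, and the monolithic `torchzero/optim/wrappers/scipy.py` is split into a `scipy/` subpackage. A minimal, hypothetical compatibility shim for downstream code that imported the linalg submodules from the old location (the submodule names are taken from the rename list; nothing in this diff guarantees any particular re-exports):

```python
# Hypothetical downstream shim, not part of torchzero: prefer the 0.4.0 module
# layout and fall back to the 0.3.14 one. Submodule names come from the rename
# list above (linear_operator.py, qr.py, solve.py moved under torchzero/linalg).
try:
    from torchzero.linalg import linear_operator, qr, solve        # 0.4.0 layout
except ImportError:
    from torchzero.utils.linalg import linear_operator, qr, solve  # 0.3.14 layout
```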
torchzero/modules/adaptive/sam.py
@@ -1,10 +1,10 @@
 from contextlib import nullcontext
 import torch
-from ...utils import TensorList, NumberList
-from ...core import Module
+from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
+from ...core import Transform


-class SAM(Module):
+class SAM(Transform):
     """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

     SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
@@ -22,50 +22,51 @@ class SAM(Module):
         p (float, optional): norm of the SAM objective. Defaults to 2.
         asam (bool, optional):
             enables ASAM variant which makes perturbation relative to weight magnitudes.
-            ASAM requires a much larger :code:`rho`, like 0.5 or 1.
-            The :code:`tz.m.ASAM` class is idential to setting this argument to True, but
-            it has larger :code:`rho` by default.
+            ASAM requires a much larger ``rho``, like 0.5 or 1.
+            The ``tz.m.ASAM`` class is idential to setting this argument to True, but
+            it has larger ``rho`` by default.

-    Examples:
-        SAM-SGD:
+    ### Examples:

-        .. code-block:: python
+    SAM-SGD:

-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.SAM(),
-                tz.m.LR(1e-2)
-            )
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SAM(),
+        tz.m.LR(1e-2)
+    )
+    ```

-        SAM-Adam:
+    SAM-Adam:

-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.SAM(),
-                tz.m.Adam(),
-                tz.m.LR(1e-2)
-            )
+    ```
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SAM(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2)
+    )
+    ```

     References:
-        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+        [Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.](https://arxiv.org/abs/2010.01412#page=3.16)
     """
     def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
         defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
         super().__init__(defaults)

     @torch.no_grad
-    def step(self, var):
+    def update_states(self, objective, states, settings):

-        params = var.params
-        closure = var.closure
-        zero_grad = var.zero_grad
+        params = objective.params
+        closure = objective.closure
+        zero_grad = objective.zero_grad
         if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
-        p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
-        s = self.defaults
-        eps = s['eps']
-        asam = s['asam']
+        p, rho = unpack_dicts(settings, 'p', 'rho', cls=NumberList)
+        fs = settings[0]
+        eps = fs['eps']
+        asam = fs['asam']

         # 1/p + 1/q = 1
         # okay, authors of SAM paper, I will manually solve your equation
@@ -123,8 +124,7 @@ class SAM(Module):

             return sam_loss

-        var.closure = sam_closure
-        return var
+        objective.closure = sam_closure

 # different class because defaults for SAM are bad for ASAM
 class ASAM(SAM):
@@ -136,7 +136,7 @@ class ASAM(SAM):
     This implementation modifies the closure to return loss and calculate gradients
     of the SAM objective. All modules after this will use the modified objective.

-    .. note::
+    Note:
         This module requires a closure passed to the optimizer step,
         as it needs to re-evaluate the loss and gradients at two points on each step.

@@ -144,20 +144,30 @@ class ASAM(SAM):
         rho (float, optional): Neighborhood size. Defaults to 0.05.
         p (float, optional): norm of the SAM objective. Defaults to 2.

-    Examples:
-        ASAM-Adam:
+    ### Examples:
+
+    ASAM-SGD:

-        .. code-block:: python
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ASAM(),
+        tz.m.LR(1e-2)
+    )
+    ```

-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.ASAM(),
-                tz.m.Adam(),
-                tz.m.LR(1e-2)
-            )
+    ASAM-Adam:

+    ```
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ASAM(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2)
+    )
+    ```
     References:
-        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+        [Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.](https://arxiv.org/abs/2102.11600)
     """
     def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
         super().__init__(rho=rho, p=p, eps=eps, asam=True)
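As the docstring in the hunk above says, the module rewrites the closure so that later modules see the loss and gradients of the SAM objective: perturb the parameters along the gradient, re-evaluate, then restore. A minimal sketch of that objective for the common `p=2` case, written against plain PyTorch rather than torchzero's `Objective`/`Transform` machinery, so the general-`p` dual-norm solve and the ASAM per-weight scaling from the diff are omitted:

```python
# Minimal sketch of the SAM objective (p=2) that the modified closure evaluates,
# based on the cited Foret et al. paper; torchzero's own sam_closure additionally
# handles general p via the dual norm (1/p + 1/q = 1) and the ASAM scaling.
import torch

def sam_objective(params, closure, rho=0.05, eps=1e-10):
    loss = closure()                                   # loss at the current point
    grads = torch.autograd.grad(loss, params)
    grad_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
    e_ws = [g * (rho / (grad_norm + eps)) for g in grads]
    with torch.no_grad():                              # ascend to the worst-case neighbor
        for p, e in zip(params, e_ws): p.add_(e)
    sam_loss = closure()                               # loss (and grads) at the perturbed point
    with torch.no_grad():                              # restore the original parameters
        for p, e in zip(params, e_ws): p.sub_(e)
    return sam_loss

w = torch.randn(10, requires_grad=True)
print(sam_objective([w], lambda: (w ** 2).sum()))
```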
torchzero/modules/adaptive/shampoo.py
@@ -1,11 +1,10 @@
 from collections.abc import Sequence
-from operator import itemgetter
-from functools import partial
+
 import numpy as np
 import torch

-from ...core import Chainable, Transform, apply_transform
-from ...utils.linalg import matrix_power_eigh
+from ...core import Chainable, TensorTransform
+from ...linalg.matrix_power import MatrixPowerMethod, matrix_power as _matrix_power
 from ...utils import set_storage_


@@ -14,10 +13,11 @@ def update_shampoo_preconditioner_(
     accumulators_: list[torch.Tensor | None],
     preconditioners_: list[torch.Tensor | None],
     step: int,
-    update_freq: int,
-    exp_override: int | None,
+    precond_freq: int,
+    matrix_power: float | None,
     beta: float | None,
-    reg: float
+    reg: float,
+    matrix_power_method: MatrixPowerMethod,
 ):
     for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
         if accumulator is None: continue
@@ -27,22 +27,20 @@
         if beta is None: accumulator.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: accumulator.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

-        if step % update_freq == 0:
-            matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
+        if step % precond_freq == 0:
             if reg != 0:
                 accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
-            set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))

+            if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
+            set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))

 def apply_shampoo_preconditioner(
     tensor: torch.Tensor,
     preconditioners_: list[torch.Tensor | None],
-    decay: float | None,
 ):
     for i, preconditioner in enumerate(preconditioners_):
         if preconditioner is None: continue
         tensor = torch.tensordot(tensor, preconditioner, ([0], [0])) # pyright:ignore[reportArgumentType]
-        if decay is not None: preconditioner.mul_(decay)
     return tensor


@@ -50,9 +48,8 @@ def update_diagonal_(grad: torch.Tensor, diagonal_accumulator_: torch.Tensor, be
     if beta is None: diagonal_accumulator_.add_(grad.pow(2))
     else: diagonal_accumulator_.mul_(beta).addcmul_(grad, grad, value=1-beta)

-def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, decay: float | None, eps: float):
+def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, eps: float):
     grad_.div_(diagonal_accumulator_.sqrt() + eps)
-    if decay is not None: diagonal_accumulator_.mul_(decay)
     return grad_

 def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
@@ -86,144 +83,141 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     return tensor.permute(*np.argsort(sort_idxs).tolist())


-class Shampoo(Transform):
+class Shampoo(TensorTransform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

-    .. note::
+    Notes:
         Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

-    .. note::
-        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+        Shampoo is a very computationally expensive optimizer, increase ``update_freq`` if it is too slow.

-    .. note::
-        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as ``tz.m.SOAP``.

     Args:
-        decay (float | None, optional): slowly decays preconditioners. Defaults to None.
-        beta (float | None, optional):
-            if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
         update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
+        matrix_power (float | None, optional): overrides matrix exponent. By default uses ``-1/grad.ndim``. Defaults to None.
        merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
-        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
+        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 10_000.
        precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
        adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
+        matrix_power_method (MatrixPowerMethod, optional): how to compute matrix power.
+        beta (float | None, optional):
+            if None calculates sum as in standard Shampoo, otherwise uses EMA of preconditioners. Defaults to None.
        inner (Chainable | None, optional):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.

     Examples:
-        Shampoo grafted to Adam
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.GraftModules(
-                    direction = tz.m.Shampoo(),
-                    magnitude = tz.m.Adam(),
-                ),
-                tz.m.LR(1e-3)
-            )
-
-        Adam with Shampoo preconditioner
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
-                tz.m.Debias(0.9, 0.999),
-                tz.m.LR(1e-3)
-            )
+    Shampoo grafted to Adam
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.GraftModules(
+            direction = tz.m.Shampoo(),
+            magnitude = tz.m.Adam(),
+        ),
+        tz.m.LR(1e-3)
+    )
+    ```
+
+    Adam with Shampoo preconditioner
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(1e-3)
+    )
+    ```
     """
     def __init__(
         self,
-        decay: float | None = None,
-        beta: float | None = None,
         reg: float = 1e-12,
-        update_freq: int = 10,
-        exp_override: int | None = 2,
+        precond_freq: int = 10,
+        matrix_power: float | None = None,
         merge_small: bool = True,
-        max_dim: int = 2_000,
+        max_dim: int = 10_000,
         precondition_1d: bool = True,
         adagrad_eps: float = 1e-8,
+        matrix_power_method: MatrixPowerMethod = "eigh_abs",
+        beta: float | None = None,
+        beta_debias: bool = True,
+
         inner: Chainable | None = None,
     ):
-        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps, reg=reg)
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        merged_tensors = [] # target with merged dims
-
-        # update preconditioners
-        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
-            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d, reg = itemgetter(
-                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d', "reg")(setting)
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            merged_tensors.append(t)
-
-            # initialize accumulators and preconditioners for each dim on 1st step
-            if 'accumulators' not in state:
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['accumulators'] = []
-
-                else:
-                    state['accumulators'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-                    state['preconditioners'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or too big, then basic diagonal preconditioner is used.
-                if len([i is not None for i in state['accumulators']]) == 0:
-                    state['diagonal_accumulator'] = torch.zeros_like(t)
-
-                state['step'] = 0
-
-            # update preconditioners
-            if 'diagonal_accumulator' in state:
-                update_diagonal_(t, state['diagonal_accumulator'], beta)
-            else:
-                update_shampoo_preconditioner_(
-                    t,
-                    accumulators_=state['accumulators'],
-                    preconditioners_=state['preconditioners'],
-                    step=state['step'],
-                    update_freq=update_freq,
-                    exp_override=exp_override,
-                    beta=beta,
-                    reg=reg,
-                )
-
-        # inner step
-        if 'inner' in self.children:
-            tensors = apply_transform(self.children['inner'], tensors, params=params, grads=grads)
-
-        # have to merge small dims again
-        merged_tensors = [] # target with merged dims
-        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
-            if setting['merge_small']:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, setting['max_dim'])
-            merged_tensors.append(t)
-
-        # precondition
-        for i,(t,state, setting) in enumerate(zip(merged_tensors, states, settings)):
-            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(setting)
-
-            if 'diagonal_accumulator' in state:
-                tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)
-            else:
-                tensors[i] = apply_shampoo_preconditioner(t, preconditioners_=state['preconditioners'], decay=decay)
-
-            if merge_small:
-                tensors[i] = _unmerge_small_dims(tensors[i], state['flat_sizes'], state['sort_idxs'])
-
-            state['step'] += 1
-
-        return tensors
+        defaults = locals().copy()
+        del defaults['self'], defaults["inner"]
+
+        super().__init__(defaults, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if tensor.ndim <= 1 and not setting["precondition_1d"]:
+            state["accumulators"] = []
+
+        else:
+            max_dim = setting["max_dim"]
+            state['accumulators'] = [
+                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
+            state['preconditioners'] = [
+                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
+
+        # either scalar parameter, 1d with precondition_1d=False, or too big, then diagonal preconditioner is used.
+        if len([i is not None for i in state['accumulators']]) == 0:
+            state['diagonal_accumulator'] = torch.zeros_like(tensor)
+
+        state['step'] = 0
+        state["num_GTG"] = 0
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if 'diagonal_accumulator' in state:
+            update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
+        else:
+            update_shampoo_preconditioner_(
+                tensor,
+                accumulators_=state['accumulators'],
+                preconditioners_=state['preconditioners'],
+                step=state['step'],
+                precond_freq=setting["precond_freq"],
+                matrix_power=setting["matrix_power"],
+                beta=setting["beta"],
+                reg=setting["reg"],
+                matrix_power_method=setting["matrix_power_method"],
+            )
+
+        if state["step"] % setting["precond_freq"] == 0:
+            state["num_GTG"] += 1
+
+        state["step"] += 1
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if 'diagonal_accumulator' in state:
+            dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
+        else:
+            dir = apply_shampoo_preconditioner(tensor, preconditioners_=state['preconditioners'])
+
+        if setting["merge_small"]:
+            dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])
+
+        if setting['beta_debias'] and setting["beta"] is not None:
+            bias_correction = 1 - (setting["beta"] ** state["num_GTG"])
+            dir *= bias_correction ** 0.5
+
+        return dir
+
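Two behavioural changes in the Shampoo hunk are worth spelling out: the preconditioner exponent now defaults to `-1/max(grad.ndim, 2)` (previously the `exp_override=2` default meant an exponent of -1/2 for every tensor), and when `beta` is set the new `beta_debias` flag rescales the preconditioned direction by `sqrt(1 - beta ** num_GTG)`, where `num_GTG` counts preconditioner refreshes. A minimal sketch of the underlying two-sided update for a 2-D gradient, using a plain `eigh`-based matrix power as a stand-in for torchzero's `linalg.matrix_power(..., method="eigh_abs")`:

```python
# Sketch of two-sided Shampoo for a 2-D gradient G, mirroring what
# update_shampoo_preconditioner_ / apply_shampoo_preconditioner do per axis via
# tensordot. Not torchzero's code: exponent=-0.5 mirrors the new default
# -1/max(G.ndim, 2), while the Shampoo paper itself uses -1/(2*G.ndim).
import torch

def sym_matrix_power(A: torch.Tensor, power: float, reg: float = 1e-12) -> torch.Tensor:
    # symmetric matrix power via eigendecomposition (approximation of "eigh_abs")
    A = A + reg * torch.eye(A.size(0), dtype=A.dtype, device=A.device)
    w, Q = torch.linalg.eigh(A)
    return Q @ torch.diag(w.abs().pow(power)) @ Q.T

def shampoo_direction(G, L_acc, R_acc, exponent=-0.5, beta=None, num_GTG=1, beta_debias=True):
    # accumulate per-axis Gram matrices: running sum for classic Shampoo, EMA when beta is set
    if beta is None:
        L_acc += G @ G.T
        R_acc += G.T @ G
    else:
        L_acc.lerp_(G @ G.T, 1 - beta)
        R_acc.lerp_(G.T @ G, 1 - beta)
    d = sym_matrix_power(L_acc, exponent) @ G @ sym_matrix_power(R_acc, exponent)
    if beta is not None and beta_debias:
        d = d * (1 - beta ** num_GTG) ** 0.5   # same sqrt bias-correction as single_tensor_apply
    return d

G = torch.randn(4, 3)
d = shampoo_direction(G, torch.zeros(4, 4), torch.zeros(3, 3))
```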