torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/conjugate_gradient/cg.py

@@ -3,21 +3,14 @@ from typing import Literal
 
 import torch
 
-from ...core import (
-    Chainable,
-    Modular,
-    Module,
-    Transform,
-    Var,
-    apply_transform,
-)
-from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
-from ..line_search import LineSearchBase
+from ...core import Chainable, TensorTransform
+
+from ...utils import TensorList, safe_dict_update_, unpack_dicts, unpack_states
 from ..quasi_newton.quasi_newton import HessianUpdateStrategy
-from ..functional import safe_clip
+from ..opt_utils import safe_clip
 
 
-class ConguateGradientBase(Transform, ABC):
+class ConguateGradientBase(TensorTransform, ABC):
     """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
 
     This is an abstract class, to use it, subclass it and override `get_beta`.
@@ -52,13 +45,8 @@ class ConguateGradientBase(Transform, ABC):
     """
     def __init__(self, defaults, clip_beta: bool, restart_interval: int | None | Literal['auto'], inner: Chainable | None = None):
         if defaults is None: defaults = {}
-        defaults['restart_interval'] = restart_interval
-        defaults['clip_beta'] = clip_beta
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
+        safe_dict_update_(defaults, dict(restart_interval=restart_interval, clip_beta=clip_beta))
+        super().__init__(defaults, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
@@ -74,40 +62,38 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
-        params = as_tensorlist(params)
-
-        step = self.global_state.get('step', 0) + 1
-        self.global_state['step'] = step
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        params = TensorList(params)
+        self.increment_counter("step", start=0)
 
         # initialize on first step
-        if self.global_state.get('stage', 0) == 0:
+        if self.global_state.get('stage', "first update") == "first update":
            g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
            d_prev.copy_(tensors)
            g_prev.copy_(tensors)
            self.initialize(params, tensors)
-            self.global_state['stage'] = 1
+            self.global_state['stage'] = "first apply"
 
         else:
            # if `update_tensors` was called multiple times before `apply_tensors`,
            # stage becomes 2
-            self.global_state['stage'] = 2
+            self.global_state['stage'] = "initialized"
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
         step = self.global_state['step']
 
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+        assert self.global_state['stage'] != "first update"
 
-        assert self.global_state['stage'] != 0
-        if self.global_state['stage'] == 1:
-            self.global_state['stage'] = 2
+        # on 1st apply we don't have previous gradients
+        # so just return tensors
+        if self.global_state['stage'] == "first apply":
+            self.global_state['stage'] = "initialized"
            return tensors
 
-        params = as_tensorlist(params)
+        params = TensorList(params)
         g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
 
         # get beta
@@ -119,10 +105,13 @@ class ConguateGradientBase(Transform, ABC):
         dir = tensors.add_(d_prev.mul_(beta))
         d_prev.copy_(dir)
 
-        # resetting
+        # resetting every `reset_interval` steps, use step+1 to not reset on 1st step
+        # so if reset_interval=2, then 1st step collects g_prev and d_prev, then
+        # two steps will happen until reset.
         restart_interval = settings[0]['restart_interval']
         if restart_interval == 'auto': restart_interval = tensors.global_numel() + 1
-        if restart_interval is not None and step % restart_interval == 0:
+
+        if restart_interval is not None and (step + 1) % restart_interval == 0:
            self.state.clear()
            self.global_state.clear()
 
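The refactor changes the base-class plumbing, not the recurrence itself. For reference, a minimal standalone sketch (plain torch, not the torchzero API; Polak-Ribière is just one possible `get_beta`) of what a subclass computes each step:

    import torch

    def polak_ribiere_beta(g: torch.Tensor, g_prev: torch.Tensor) -> torch.Tensor:
        # beta = <g, g - g_prev> / <g_prev, g_prev>, one common choice of get_beta
        denom = g_prev.dot(g_prev).clamp(min=torch.finfo(g.dtype).tiny)
        return g.dot(g - g_prev) / denom

    g_prev = torch.randn(10)
    d_prev = g_prev.clone()                             # first step: direction is the raw gradient
    g = torch.randn(10)                                 # gradient on the next step
    beta = polak_ribiere_beta(g, g_prev).clamp(min=0)   # clip_beta=True clips beta at zero
    d = g + beta * d_prev                               # mirrors `tensors.add_(d_prev.mul_(beta))`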
torchzero/modules/experimental/__init__.py

@@ -1,19 +1,20 @@
 """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
+from .adanystrom import AdaNystrom
+from .common_directions_whiten import CommonDirectionsWhiten
+from .coordinate_momentum import CoordinateMomentum
+from .cubic_adam import CubicAdam, SubspaceCubicAdam
 from .curveball import CurveBall
+from .eigen_sr1 import EigenSR1
 
 # from dct import DCTProjection
+from .eigengrad import Eigengrad
 from .fft import FFTProjection
 from .gradmin import GradMin
 from .higher_order_newton import HigherOrderNewton
 from .l_infinity import InfinityNormTrustRegion
-from .momentum import (
-    CoordinateMomentum,
-    NesterovEMASquared,
-    PrecenteredEMASquared,
-    SqrtNesterovEMASquared,
-)
 from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
 from .reduce_outward_lr import ReduceOutwardLR
 from .scipy_newton_cg import ScipyNewtonCG
+from .spsa1 import SPSA1
 from .structural_projections import BlockPartition, TensorizeProjection
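With these re-exports, the new experimental modules resolve from the subpackage directly; a usage sketch (class names taken from the imports above):

    from torchzero.modules.experimental import AdaNystrom, CoordinateMomentum, EigenSR1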
torchzero/modules/experimental/adanystrom.py (new file)

@@ -0,0 +1,258 @@
+# pylint: disable = non-ascii-name
+import torch
+
+from ...core import Chainable, TensorTransform
+from ...linalg import (
+    OrthogonalizeMethod,
+    orthogonalize,
+    regularize_eigh,
+    torch_linalg,
+)
+from ...linalg.linear_operator import Eigendecomposition
+from ..adaptive.lre_optimizers import LREOptimizerBase
+from .eigengrad import _eigengrad_update_state_, eigengrad_apply
+
+
+def weighted_eigen_plus_rank1_mm(
+    # A1 = Q1 @ diag(L1) @ Q1.T
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+
+    # A2 = v2 @ v2.T
+    v2: torch.Tensor,
+
+    # second matrix
+    B: torch.Tensor,
+
+    # weights
+    w1: float,
+    w2: float,
+
+) -> torch.Tensor:
+    """
+    Computes ``(w1 * A1 + w2 * A2) @ B``, where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    Returns ``(n, k)``
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)``.
+        B (torch.Tensor): shape ``(n, k)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+
+    """
+    # sketch A1
+    QTB = Q1.T @ B  # (rank, k)
+    LQTB = L1.unsqueeze(1) * QTB  # (rank, k)
+    sketch1 = Q1 @ LQTB  # (n, k)
+
+    # sketch A2
+    vB = v2 @ B
+    sketch2 = v2.outer(vB)
+
+    return w1 * sketch1 + w2 * sketch2
+
+
+def adanystrom_update(
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+    v2: torch.Tensor,
+    w1: float,
+    w2: float,
+    oversampling_p: int,
+    rank: int,
+    eig_tol: float,
+    damping: float,
+    rdamping: float,
+    orthogonalize_method: OrthogonalizeMethod,
+
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    """computes the Nyström approximation of ``(w1 * A1 + w2 * A2)``,
+    where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    returns L of shape ``(k, )`` and Q of shape ``(n, k)``.
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)`` or ``(n, 1)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+    """
+    n = Q1.shape[0]
+    device = Q1.device
+    dtype = Q1.dtype
+    l = rank + oversampling_p
+
+    # gaussian test matrix
+    Omega = torch.randn(n, l, device=device, dtype=dtype)
+
+    # sketch
+    AOmega = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Omega, w1, w2)
+    Q = orthogonalize(AOmega, orthogonalize_method)
+
+    AQ = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Q, w1, w2)
+    QTAQ = Q.T @ AQ
+
+    W = (QTAQ + QTAQ.T) / 2.0
+
+    # compute new L and Q
+    try:
+        L_prime, S = torch_linalg.eigh(W, retry_float64=True)
+    except torch.linalg.LinAlgError:
+        return L1, Q1
+
+    L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
+
+    if L_prime is None or S is None:
+        return L1, Q1
+
+    return L_prime, Q @ S
+
+
+# def adanystrom_update2(
+#     L1: torch.Tensor,
+#     Q1: torch.Tensor,
+#     v2: torch.Tensor,
+#     w1: float,
+#     w2: float,
+#     rank: int,
+# ):
+#     def A_mm(X):
+#         return weighted_eigen_plus_rank1_mm(L1=L1, Q1=Q1, v2=v2, B=X, w1=w1, w2=w2)
+
+#     return nystrom_approximation(A_mm, A_mm=A_mm, ndim=v2.numel(), rank=rank, device=L1.device, dtype=L1.dtype)
+
+class AdaNystrom(TensorTransform):
+    """Adagrad/RMSprop/Adam with Nyström-approximated covariance matrix.
+
+    Args:
+        rank (int, optional): rank of the Nyström approximation. Defaults to 100.
+        beta (float, optional):
+            EMA weight of the current covariance matrix; the new gradient outer product
+            is weighted by ``1 - beta``. Defaults to 0.95.
+        oversampling (int, optional): number of extra random vectors (top rank eigenvalues are kept). Defaults to 10.
+        eig_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-32.
+        damping (float, optional):
+            added to eigenvalues when updating the preconditioner. Defaults to 0.
+        rdamping (float, optional):
+            added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
+        mm_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 0.
+        mm_truncate (int | None, optional):
+            uses top k eigenvalues to compute the update. Defaults to None.
+        mm_damping (float, optional):
+            added to eigenvalues when computing the update. Defaults to 0.
+        mm_rdamping (float, optional):
+            added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
+        id_reg (float, optional):
+            multiplier to identity matrix added to preconditioner before computing the update.
+            If this value is given, the solution from Nyström sketch-and-solve will be used to compute the update.
+            This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
+        concat_params (bool, optional):
+            whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
+        update_freq (int, optional): update frequency. Defaults to 1.
+        inner (Chainable | None, optional): inner modules. Defaults to None.
+    """
+    def __init__(
+        self,
+        rank: int = 100,
+        beta=0.95,
+        oversampling: int = 10,
+        eig_tol: float | None = 1e-32,
+        damping: float = 0,
+        rdamping: float = 0,
+        mm_tol: float = 0,
+        mm_truncate: int | None = None,
+        mm_damping: float = 0,
+        mm_rdamping: float = 0,
+        id_reg: float | None = None,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        orthogonalize_interval: int | None = 100,
+
+        concat_params: bool = True,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        for k in ["self", "concat_params", "inner", "update_freq"]:
+            del defaults[k]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
+
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        state["step"] = state.get("step", 0) + 1
+        rank = setting["rank"]
+        device = tensor.device
+        dtype = tensor.dtype
+        beta = setting["beta"]
+
+        try:
+            if "L" not in state:
+                # use just tensor and zero L and Q with zero weight
+
+                L, Q = adanystrom_update(
+                    L1=torch.zeros(rank, device=device, dtype=dtype),
+                    Q1=torch.zeros((tensor.numel(), rank), device=device, dtype=dtype),
+                    v2=tensor.ravel(),
+                    w1=0,
+                    w2=1-beta,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                state["L"] = state["L_reg"] = L
+                state["Q"] = state["Q_reg"] = Q
+
+            else:
+                L = state["L"]
+                Q = state["Q"]
+
+                w1 = beta
+                w2 = 1 - w1
+
+                # compute new factors (this function truncates them)
+                L_new, Q_new = adanystrom_update(
+                    L1=L,
+                    Q1=Q,
+                    v2=tensor.ravel(),
+                    w1=w1,
+                    w2=w2,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
+
+        except torch.linalg.LinAlgError:
+            pass
+
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if "L_reg" not in state:
+            return tensor.clip(-0.1, 0.1)
+
+        if "eigenbasis_state" not in state:
+            state["eigenbasis_state"] = {}
+
+        return eigengrad_apply(
+            tensor=tensor,
+            L_reg = state["L_reg"],
+            Q_reg = state["Q_reg"],
+            beta = setting["beta"],
+            step = state["step"],
+            debias = True,
+            id_reg = setting["id_reg"],
+            eigenbasis_optimizer = setting["eigenbasis_optimizer"],
+            eigenbasis_state = state["eigenbasis_state"]
+        )
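`weighted_eigen_plus_rank1_mm` is the workhorse here: it applies the weighted operator to a test matrix without ever forming the dense ``n x n`` matrix, so each sketch costs O(n·rank·k) instead of O(n²·k). A minimal check (plain torch, illustrative sizes) that the factored product matches the dense one:

    import torch

    n, rank, k = 6, 3, 2
    L1 = torch.rand(rank)
    Q1, _ = torch.linalg.qr(torch.randn(n, rank))   # orthonormal columns
    v2 = torch.randn(n)
    B = torch.randn(n, k)
    w1, w2 = 0.95, 0.05

    # dense reference: (w1 * Q1 diag(L1) Q1^T + w2 * v2 v2^T) @ B
    dense = w1 * (Q1 @ torch.diag(L1) @ Q1.T) + w2 * torch.outer(v2, v2)

    # factored path, mirroring weighted_eigen_plus_rank1_mm
    factored = w1 * (Q1 @ (L1.unsqueeze(1) * (Q1.T @ B))) + w2 * torch.outer(v2, v2 @ B)

    assert torch.allclose(dense @ B, factored, atol=1e-5)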
torchzero/modules/experimental/common_directions_whiten.py (new file)

@@ -0,0 +1,142 @@
+from collections import deque
+from typing import Literal
+
+import torch
+
+from torchzero.core import Chainable, TensorTransform
+from torchzero.linalg import matrix_power_eigh, torch_linalg, orthogonalize, OrthogonalizeMethod, regularize_eigh
+from torchzero.utils import TensorList, vec_to_tensors_
+
+
+def update_subspace_preconditioner_(
+    grad: torch.Tensor,  # store grads and basis as vectors for matmul
+    basis: torch.Tensor,  # ndim, k
+    accumulator_: torch.Tensor,  # k, k
+    beta: float | None,
+):
+    projected = basis.T @ grad  # k
+    outer = torch.outer(projected, projected)
+
+    if beta is None: accumulator_.add_(outer)
+    else: accumulator_.lerp_(outer, 1-beta)
+
+# subspace optimizers could also be run in this basis
+def apply_subspace_preconditioner(
+    tensor: torch.Tensor,
+    basis: torch.Tensor,  # ndim, k
+    accumulator: torch.Tensor,
+    tol: float,
+    truncate: int | None,
+    damping: float,
+    rdamping: float,
+):
+    L, Q = torch_linalg.eigh(accumulator, retry_float64=True)
+    L, Q = regularize_eigh(L=L, Q=Q, truncate=truncate, tol=tol, damping=damping, rdamping=rdamping)
+
+    if L is None or Q is None:
+        return tensor.clip(-0.1, 0.1)
+
+    preconditioner = (Q * L.rsqrt().unsqueeze(-2)) @ Q.mH
+
+    tensor_projected = basis.T @ tensor  # k
+    update_projected = preconditioner @ tensor_projected  # k
+    return basis @ update_projected  # d
+
+
+class CommonDirectionsWhiten(TensorTransform):
+    """Whitens in the subspace spanned by a history of gradient differences.
+
+    Args:
+        beta: EMA decay of the preconditioner accumulator in the basis.
+        basis_beta: EMA decay of the basis itself, i.e. how much the basis is allowed to change per step.
+    """
+
+    def __init__(
+        self,
+        k: int = 100,
+        beta: float | None = 0.95,
+        basis_beta=0.95,
+        tol: float = 1e-7,
+        truncate: int | None = None,
+        damping: float = 1e-4,
+        rdamping: float = 0,
+        basis_type: Literal["gradients", "differences"] = "differences",
+        orthogonalize_method: OrthogonalizeMethod | None = 'newtonschulz',
+
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        for key in ["self", "inner", "concat_params"]:
+            del defaults[key]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        g = tensor.ravel()
+        k = setting['k']
+        beta = setting['beta']
+        basis_beta = setting['basis_beta']
+        step = state.get("step", 0)
+        state["step"] = step + 1
+
+        # initialize history
+        if 'history' not in state:
+            state['history'] = deque(maxlen=k)
+            state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
+            state['basis'] = torch.zeros(g.numel(), k, device=g.device, dtype=g.dtype)
+
+        history: deque = state['history']
+        accumulator = state['accumulator']
+        basis = state['basis']
+        history.append(g)
+
+        # stack history to new basis term, if history isn't full, fill with random vecs
+        if len(history) < k:
+            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
+            history_basis = torch.stack(tuple(history), -1)
+            basis_t[:, -len(history):] = history_basis
+
+        else:
+            basis_t = torch.stack(tuple(history), -1)
+
+        # in this case basis uses differences in gradients except last entry is the gradient
+        if setting["basis_type"] == "differences":
+            basis_t[:, :-1] = basis_t[:, :-1] - basis_t[:, 1:]
+
+        # normalize or orthonormalize new basis term
+        if setting["orthogonalize_method"] is not None:
+            basis_t = orthogonalize(basis_t, method=setting["orthogonalize_method"])
+        else:
+            basis_t = (basis_t - basis_t.mean()) / basis_t.std().clip(min=torch.finfo(g.dtype).tiny * 2)
+
+        # lerp basis
+        basis.lerp_(basis_t, 1-basis_beta)
+        basis = basis / (1 - basis_beta ** (step+1))  # correct bias on basis EMA
+        update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        g = tensor.ravel()
+
+        basis = state['basis']
+        accumulator = state['accumulator']
+        step = state["step"]
+        accumulator = accumulator / (1 - setting["beta"] ** (step+1))  # correct bias on accumulator EMA
+
+        try:
+            preconditioned = apply_subspace_preconditioner(
+                g,
+                basis,
+                accumulator,
+                tol=setting["tol"],
+                truncate=setting["truncate"],
+                damping=setting["damping"],
+                rdamping=setting["rdamping"],
+            )
+        except torch.linalg.LinAlgError:
+            preconditioned = g.clip(-0.1, 0.1)
+
+        return preconditioned.view_as(tensor)
+
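The two halves compose into subspace whitening: gradients are projected onto the k-dimensional basis, a covariance accumulator is maintained there, and the update is preconditioned by the accumulator's inverse square root before being mapped back. A minimal dense sketch (plain torch, illustrative sizes, no EMA bias correction):

    import torch

    d, k = 50, 5
    basis, _ = torch.linalg.qr(torch.randn(d, k))    # orthonormal (d, k) basis

    # accumulate projected gradient covariance, as update_subspace_preconditioner_ does
    acc = torch.eye(k)
    for _ in range(100):
        pg = basis.T @ torch.randn(d)                # project gradient into the subspace
        acc.lerp_(torch.outer(pg, pg), 0.05)         # beta = 0.95

    # whiten a new gradient: basis @ acc^{-1/2} @ basis^T @ g
    g = torch.randn(d)
    L, Q = torch.linalg.eigh(acc)
    inv_sqrt = (Q * L.clamp(min=1e-12).rsqrt()) @ Q.T
    whitened = basis @ (inv_sqrt @ (basis.T @ g))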
torchzero/modules/experimental/coordinate_momentum.py (new file)

@@ -0,0 +1,36 @@
+import torch
+
+from ...core import TensorTransform
+from ...utils import NumberList, TensorList, unpack_states
+
+
+def coordinate_momentum_(
+    tensors: TensorList,
+    velocity_: TensorList,
+    p: float | NumberList,
+):
+    """
+    Overwrites a random fraction ``p`` of values in ``velocity_`` with the corresponding values from ``tensors``.
+
+    Returns ``velocity_``.
+    """
+    mask = tensors.bernoulli_like(p).as_bool()
+    velocity_.masked_set_(mask, tensors)
+    return velocity_
+
+
+class CoordinateMomentum(TensorTransform):
+    """Maintains a momentum buffer; on each step each value in the buffer has ``p`` chance to be updated with the new value.
+
+    Args:
+        p (float, optional): probability for each buffer value to be replaced by the new value. Defaults to 0.1.
+    """
+    def __init__(self, p: float = 0.1):
+        defaults = dict(p=p)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        p = NumberList(s['p'] for s in settings)
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
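For intuition, the buffer update reduces to a per-coordinate Bernoulli choice; a minimal sketch with plain tensors (the module does the same through `TensorList.bernoulli_like` and `masked_set_`):

    import torch

    p = 0.1
    update = torch.randn(4, 4)
    velocity = torch.zeros(4, 4)

    # each buffer entry independently has probability p of taking the new value
    mask = torch.bernoulli(torch.full_like(velocity, p)).bool()
    velocity[mask] = update[mask]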