torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/esgd.py
@@ -1,49 +1,20 @@
-import math
-from collections.abc import Callable
 from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, Target, Transform, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...core import Chainable, HVPMethod, Transform
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
 
 
-def esgd_(
-    tensors_: TensorList,
-    D: TensorList | None,
-    D_sq_acc_: TensorList,
-    damping: float | NumberList,
-    update_freq: int,
-    step: int,
-    i: int,
-):
-    # update preconditioner
-    if step % update_freq == 0:
-        assert D is not None
-        D_sq_acc_.addcmul_(D, D)
-        i += 1
-    else:
-        assert D is None
-
-    denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
-    return tensors_.div_(denom), i
-
-
-class ESGD(Module):
+class ESGD(Transform):
     """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)
 
     This is similar to Adagrad, but the accumulates squared randomized hessian diagonal estimates instead of squared gradients.
 
-    .. note::
-        In most cases Adagrad should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Adagrad preconditioning to another module's output.
+    Notes:
+        - In most cases ESGD should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply ESGD preconditioning to another module's output.
 
-    .. note::
-        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
-
-    .. note::
-        This module requires a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
         damping (float, optional): added to denominator for stability. Defaults to 1e-4.
@@ -51,17 +22,17 @@ class ESGD(Module):
             frequency of updating hessian diagonal estimate via a hessian-vector product.
             This value can be increased to reduce computational cost. Defaults to 20.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
-
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+            Determines how hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -72,100 +43,108 @@ class ESGD(Module):
     2. pass inputs to :code:`inner`.
     3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
 
-    Examples:
-        Using ESGD:
-
-        .. code-block:: python
+    ### Examples:
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.ESGD(),
-                tz.m.LR(0.1)
-            )
+    Using ESGD:
+    ```python
 
-        ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
-        ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ESGD(),
+        tz.m.LR(0.1)
+    )
+    ```
 
-        .. code-block:: python
+    ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
+    ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
-                tz.m.LR(0.1)
-            )
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```
 
     """
     def __init__(
         self,
         damping: float = 1e-4,
         update_freq: int = 20,
-        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
-        fd_h: float = 1e-3,
+        distribution: Distributions = 'gaussian',
+        hvp_method: HVPMethod = 'autograd',
+        h: float = 1e-3,
         n_samples = 1,
+        zHz: bool = False,
         seed: int | None = None,
-        inner: Chainable | None = None
+        beta: float | None = None,
+        beta_debias: bool = True,
+
+        inner: Chainable | None = None,
+        Hz_sq_acc_tfm: Chainable | None = None,
     ):
-        defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
-        super().__init__(defaults)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["Hz_sq_acc_tfm"]
+        super().__init__(defaults, inner=inner)
 
-        if inner is not None:
-            self.set_child('inner', inner)
+        self.set_child("Hz_sq_acc", Hz_sq_acc_tfm)
 
     @torch.no_grad
-    def step(self, var):
-        params = var.params
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        fd_h = settings['fd_h']
-        update_freq = settings['update_freq']
-        n_samples = settings['n_samples']
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
-
-        damping = self.get_settings(params, 'damping', cls=NumberList)
-        D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
-        i = self.global_state.get('i', 0)
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        closure = var.closure
-        assert closure is not None
-
-        D = None
+    def update_states(self, objective, states, settings):
+        params = objective.params
+
+        fs = settings[0]
+        update_freq = fs['update_freq']
+
+        # ------------------------------- accumulate Hz ------------------------------ #
+        step = self.increment_counter("step", start=0)
+
         if step % update_freq == 0:
+            self.increment_counter("num_Hzs", start=1)
+
+            Hz, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = fs['n_samples'],
+                distribution = fs['distribution'],
+                hvp_method = fs['hvp_method'],
+                h = fs['h'],
+                zHz = fs["zHz"], # default is False, so it returns Hz, not z⊙Hz
+                generator = self.get_generator(params[0].device, fs["seed"]),
+            )
 
-            rgrad=None
-            for j in range(n_samples):
-                u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]
+            Hz = TensorList(Hz)
+            Hz_sq_acc = unpack_states(states, params, 'Hz_sq_acc', cls=TensorList)
+
+            beta = fs["beta"]
+            if beta is None:
+                Hz_sq_acc.addcmul_(Hz, Hz)
+
+            else:
+                Hz_sq_acc.mul_(beta).addcmul_(Hz, Hz, value=1-beta)
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        tensors = TensorList(objective.get_updates())
+        Hz_sq_acc = unpack_states(states, tensors, 'Hz_sq_acc', cls=TensorList)
+        num_Hzs = self.global_state["num_Hzs"]
+        fs = settings[0]
 
-                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
-                                      h=fd_h, normalize=True, retain_grad=j < n_samples-1)
+        # ---------------------------------- debias ---------------------------------- #
+        beta = fs["beta"]
+        beta_debias = fs["beta_debias"]
 
-                if D is None: D = Hvp
-                else: torch._foreach_add_(D, Hvp)
+        if beta_debias and beta is not None:
+            bias_correction = 1.0 - beta ** num_Hzs
+            Hz_sq_acc = Hz_sq_acc / bias_correction
 
-            assert D is not None
-            if n_samples > 1: torch._foreach_div_(D, n_samples)
+        else:
+            Hz_sq_acc = Hz_sq_acc / num_Hzs
 
-            D = TensorList(D)
+        # ---------------------------------- update ---------------------------------- #
+        damping = [s["damping"] for s in settings]
 
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+        denom = (Hz_sq_acc / num_Hzs).sqrt_().add_(damping)
 
-        var.update, self.global_state['i'] = esgd_(
-            tensors_=TensorList(update),
-            D=TensorList(D) if D is not None else None,
-            D_sq_acc_=D_sq_acc,
-            damping=damping,
-            update_freq=update_freq,
-            step=step,
-            i=i,
-        )
-        return var
+        objective.updates = tensors.div_(denom)
+        return objective
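The refactor above replaces the old monolithic `step`/`esgd_` pair with an `update_states`/`apply_states` split: every `update_freq` steps a Hutchinson-style Hessian-vector product with a random vector is squared and accumulated, and the update is divided by the square root of the averaged accumulator plus `damping`. The following self-contained sketch reproduces that recipe on a toy quadratic with plain `torch` tensors; it is illustrative only and does not use the torchzero `Objective`/`TensorList` machinery, and the quadratic `H_true` and the 0.1 step size are made up for the example.

```python
import torch

torch.manual_seed(0)
A = torch.randn(8, 8)
H_true = A @ A.T              # SPD Hessian of f(x) = 0.5 * x^T H x (toy problem)
x = torch.randn(8, requires_grad=True)

damping, update_freq = 1e-4, 20
Hz_sq_acc = torch.zeros(8)    # accumulator of squared Hessian-diagonal estimates
num_Hzs = 0

for step in range(100):
    need_hvp = step % update_freq == 0
    loss = 0.5 * x @ H_true @ x
    (g,) = torch.autograd.grad(loss, x, create_graph=need_hvp)

    if need_hvp:
        # Hutchinson-style estimate: Hessian-vector product with a random Gaussian z
        z = torch.randn(8)
        (Hz,) = torch.autograd.grad(g, x, grad_outputs=z)
        Hz_sq_acc += Hz * Hz
        num_Hzs += 1

    # equilibrated denominator: sqrt of the averaged squared |Hz| plus damping
    denom = (Hz_sq_acc / max(num_Hzs, 1)).sqrt() + damping
    with torch.no_grad():
        x -= 0.1 * g.detach() / denom
```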
torchzero/modules/adaptive/lion.py
@@ -1,21 +1,16 @@
 import torch
 
-from ...core import Module, Target, Transform
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 
 def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
-    """
-    Lion update rule.
-
-    Returns new tensors.
-    """
     update = exp_avg_.lerp(tensors, 1-beta1).sign_()
     exp_avg_.lerp_(tensors, 1-beta2)
     return update
 
 
-class Lion(Transform):
+class Lion(TensorTransform):
     """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.
 
     Args:
@@ -25,11 +20,11 @@ class Lion(Transform):
 
     def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
         defaults = dict(beta1=beta1, beta2=beta2)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
         exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
-        return lion_(TensorList(tensors),exp_avg,beta1,beta2)
+        return lion_(TensorList(tensors), exp_avg, beta1, beta2)
 
torchzero/modules/adaptive/lmadagrad.py
@@ -3,9 +3,11 @@ from typing import Literal, Any
 import warnings
 
 import torch
-from ...core import Chainable, TensorwiseTransform
+from ...core import Chainable, TensorTransform
+from ...linalg import torch_linalg
 
-def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
+def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, tol):
+    """returns U ``(ndim, rank)``, L ``(rank, )``"""
     if isinstance(history, torch.Tensor):
         M = history
     else:
@@ -16,35 +18,49 @@ def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
     MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))
 
     try:
-        L, Q = torch.linalg.eigh(MTM) # pylint:disable=not-callable
+        L, Q = torch_linalg.eigh(MTM, retry_float64=True)
 
-        tol = torch.finfo(M.dtype).eps * L.amax() # remove small eigenvalues
-        indices = L > tol
-        L = L[indices]
-        Q = Q[:, indices]
+        # truncate to top n largest eigenvalues
+        if truncate is not None and truncate > 0:
+            # L is ordered in ascending order
+            L = L[-truncate:]
+            Q = Q[:, -truncate:]
+
+        # remove small eigenvalues relative to largest
+        L_max = L.amax()
+        indices = L > tol * L_max
+        if indices.any():
+            L = L[indices]
+            Q = Q[:, indices]
 
         U = (M @ Q) * L.rsqrt()
 
         if rdamping != 0:
-            rdamping *= torch.linalg.vector_norm(L) # pylint:disable=not-callable
-            L.add_(rdamping)
+            L.add_(rdamping * L_max)
 
         return U, L
 
     except torch.linalg.LinAlgError:
         return None, None
 
-def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
-    Z = U.T @ g
-    return (U * L.rsqrt()) @ Z
+def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor, exp_avg_proj: torch.Tensor | None, beta:float):
+    z = U.T @ g
+
+    if beta != 0:
+        if exp_avg_proj is None: exp_avg_proj = torch.zeros_like(z)
+        exp_avg_proj.lerp_(z, weight=1-beta)
+        z = exp_avg_proj
+
+    return (U * L.rsqrt()) @ z, exp_avg_proj
 
 def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
-    if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
+    if value is None: return
+    if (key not in state_) or (beta is None): state_[key] = value
     else:
         if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
         else: state_[key].lerp_(value, 1-beta)
 
-class LMAdagrad(TensorwiseTransform):
+class LMAdagrad(TensorTransform):
     """
     Limited-memory full matrix Adagrad.
 
@@ -55,17 +71,18 @@ class LMAdagrad(TensorwiseTransform):
 
     Args:
         history_size (int, optional): number of past gradients to store. Defaults to 10.
+        beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
         update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
         damping (float, optional): damping value. Defaults to 1e-4.
         rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
+        rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
+        truncate (int, optional): number of larges eigenvalues to keep. None to disable. Defaults to None.
+        tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
         order (int, optional):
             order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
-        true_damping (bool, optional):
-            If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
         U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
         L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
-        interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
-        concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
+        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
         inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
 
     ## Examples:
@@ -108,28 +125,35 @@ class LMAdagrad(TensorwiseTransform):
     def __init__(
         self,
         history_size: int = 100,
+        beta: float = 0.0,
         update_freq: int = 1,
         damping: float = 1e-4,
         rdamping: float = 0,
+        truncate: int | None = None,
+        tol: float = 1e-7,
         order: int = 1,
-        true_damping: bool = True,
         U_beta: float | None = None,
         L_beta: float | None = None,
-        interval: int = 1,
         concat_params: bool = True,
+
         inner: Chainable | None = None,
+        U_tfm: Chainable | None = None,
+        L_tfm: Chainable | None = None,
     ):
-        # history is still updated each step so Precondition's update_freq has different meaning
-        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults['concat_params'], defaults["U_tfm"], defaults["L_tfm"]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+        self.set_child("U", U_tfm)
+        self.set_child("L", L_tfm)
+
 
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
         order = setting['order']
         history_size = setting['history_size']
         update_freq = setting['update_freq']
-        damping = setting['damping']
-        rdamping = setting['rdamping']
         U_beta = setting['U_beta']
         L_beta = setting['L_beta']
 
@@ -165,22 +189,53 @@ class LMAdagrad(TensorwiseTransform):
 
         step = state.get('step', 0)
         if step % update_freq == 0 and len(history) != 0:
-            U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
+
+            # if maintaining momentum, unproject exp_avg before updating factors and reproject
+            exp_avg_proj = state.get("exp_avg_proj", None)
+            exp_avg = None
+            if exp_avg_proj is not None and "U" in state:
+                exp_avg = state["U"] @ exp_avg_proj
+
+            # update factors
+            U, L = lm_adagrad_update(
+                history,
+                damping=setting["damping"],
+                rdamping=setting["rdamping"],
+                truncate=setting["truncate"],
+                tol=setting["tol"],
+            )
             maybe_lerp_(state, U_beta, 'U', U)
             maybe_lerp_(state, L_beta, 'L', L)
 
+            # re-project exp_avg with new factors
+            if U is not None and exp_avg_proj is not None:
+                assert exp_avg is not None
+                state["exp_avg_proj"] = U.T @ exp_avg
+
+
         if len(history) != 0:
             state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         U = state.get('U', None)
         if U is None:
             # make a conservative step to avoid issues due to different GD scaling
-            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
+            return tensor.clip_(-0.1, 0.1)
 
+        # -------------------------------- transforms -------------------------------- #
         L = state['L']
-        update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)
-
-        return update
+        if "L" in self.children:
+            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
+            L = self.inner_step_tensors("L", [L], clone=True)[0]
+
+        if "U" in self.children:
+            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
+            U = self.inner_step_tensors("U", [U], clone=True)[0]
+
+        # ------------------------------- precondition ------------------------------- #
+        g = tensor.view(-1)
+        exp_avg_proj = state.get("exp_avg_proj", None)
+        update, state["exp_avg_proj"] = lm_adagrad_apply(g, U, L, exp_avg_proj, beta=setting["beta"])
+        return update.view_as(tensor)
 
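The core of `lm_adagrad_update`/`lm_adagrad_apply` shown above is a low-rank whitening of the gradient history: with M holding the last k gradients as columns, eigendecompose M^T M + damping*I = Q diag(L) Q^T, form U = M Q diag(L)^(-1/2), and precondition g -> U diag(L)^(-1/2) U^T g. The sketch below restates that for a single flattened parameter; it is illustrative only, and the real module additionally handles truncation, the `tol` filter, whitened-space momentum and `eigh` failure fallbacks.

```python
import torch

torch.manual_seed(0)
ndim, k, damping = 50, 10, 1e-4
M = torch.randn(ndim, k)        # history: k past gradients, one per column
g = torch.randn(ndim)           # current (flattened) gradient

MTM = M.T @ M + torch.eye(k) * damping
L, Q = torch.linalg.eigh(MTM)   # eigenvalues in ascending order, all positive here
U = (M @ Q) * L.rsqrt()         # (ndim, k), approximately orthonormal columns

z = U.T @ g                     # project into the low-rank subspace
preconditioned = (U * L.rsqrt()) @ z
print(preconditioned.shape)     # torch.Size([50])
```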
torchzero/modules/adaptive/mars.py
@@ -1,6 +1,6 @@
 import torch
 
-from ...core import Transform
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 
@@ -20,7 +20,7 @@ def mars_correction_(
 
     return c
 
-class MARSCorrection(Transform):
+class MARSCorrection(TensorTransform):
     """MARS variance reduction correction.
 
     Place any other momentum-based optimizer after this,
@@ -61,11 +61,11 @@
         scaling: float = 0.025,
         max_norm: float | None = 1,
     ):
-        defaults=dict(beta=beta, scaling=scaling, max_norm=max_norm)
-        super().__init__(defaults, uses_grad=False)
+        defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
         beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
         max_norm = settings[0]['max_norm']