torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/esgd.py
@@ -1,49 +1,20 @@
-import math
-from collections.abc import Callable
 from typing import Literal

 import torch

-from ...core import Chainable, Module, Target, Transform, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...core import Chainable, HVPMethod, Transform
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states


-def esgd_(
-    tensors_: TensorList,
-    D: TensorList | None,
-    D_sq_acc_: TensorList,
-    damping: float | NumberList,
-    update_freq: int,
-    step: int,
-    i: int,
-):
-    # update preconditioner
-    if step % update_freq == 0:
-        assert D is not None
-        D_sq_acc_.addcmul_(D, D)
-        i += 1
-    else:
-        assert D is None
-
-    denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
-    return tensors_.div_(denom), i
-
-
-class ESGD(Module):
+class ESGD(Transform):
     """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

     This is similar to Adagrad, but the accumulates squared randomized hessian diagonal estimates instead of squared gradients.

-    .. note::
-        In most cases Adagrad should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Adagrad preconditioning to another module's output.
+    Notes:
+        - In most cases ESGD should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply ESGD preconditioning to another module's output.

-    .. note::
-        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
-
-    .. note::
-        This module requires a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

     Args:
         damping (float, optional): added to denominator for stability. Defaults to 1e-4.
@@ -51,17 +22,17 @@ class ESGD(Module):
             frequency of updating hessian diagonal estimate via a hessian-vector product.
             This value can be increased to reduce computational cost. Defaults to 20.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
-
-            ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-            This requires creating a graph for the gradient.
-            ``"forward"``: Use a forward finite difference formula to
-            approximate the HVP. This requires one extra gradient evaluation.
-            ``"central"``: Use a central finite difference formula for a
-            more accurate HVP approximation. This requires two extra
-            gradient evaluations.
-            Defaults to "autograd".
-        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+            Determines how hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -72,100 +43,108 @@ class ESGD(Module):
         2. pass inputs to :code:`inner`.
         3. momentum and preconditioning are applied to the ouputs of :code:`inner`.

-    Examples:
-        Using ESGD:
-
-        .. code-block:: python
+    ### Examples:

-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.ESGD(),
-                tz.m.LR(0.1)
-            )
+    Using ESGD:
+    ```python

-        ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
-        ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.ESGD(),
+        tz.m.LR(0.1)
+    )
+    ```

-        .. code-block:: python
+    ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
+    ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):

-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
-                tz.m.LR(0.1)
-            )
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```

     """
     def __init__(
         self,
         damping: float = 1e-4,
         update_freq: int = 20,
-        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
-        fd_h: float = 1e-3,
+        distribution: Distributions = 'gaussian',
+        hvp_method: HVPMethod = 'autograd',
+        h: float = 1e-3,
         n_samples = 1,
+        zHz: bool = False,
         seed: int | None = None,
-        inner: Chainable | None = None
+        beta: float | None = None,
+        beta_debias: bool = True,
+
+        inner: Chainable | None = None,
+        Hz_sq_acc_tfm: Chainable | None = None,
     ):
-        defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
-        super().__init__(defaults)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["Hz_sq_acc_tfm"]
+        super().__init__(defaults, inner=inner)

-        if inner is not None:
-            self.set_child('inner', inner)
+        self.set_child("Hz_sq_acc", Hz_sq_acc_tfm)

     @torch.no_grad
-    def step(self, var):
-        params = var.params
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        fd_h = settings['fd_h']
-        update_freq = settings['update_freq']
-        n_samples = settings['n_samples']
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
-
-        damping = self.get_settings(params, 'damping', cls=NumberList)
-        D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
-        i = self.global_state.get('i', 0)
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        closure = var.closure
-        assert closure is not None
-
-        D = None
+    def update_states(self, objective, states, settings):
+        params = objective.params
+
+        fs = settings[0]
+        update_freq = fs['update_freq']
+
+        # ------------------------------- accumulate Hz ------------------------------ #
+        step = self.increment_counter("step", start=0)
+
         if step % update_freq == 0:
+            self.increment_counter("num_Hzs", start=1)
+
+            Hz, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = fs['n_samples'],
+                distribution = fs['distribution'],
+                hvp_method = fs['hvp_method'],
+                h = fs['h'],
+                zHz = fs["zHz"], # default is False, so it returns Hz, not z⊙Hz
+                generator = self.get_generator(params[0].device, fs["seed"]),
+            )

-            rgrad=None
-            for j in range(n_samples):
-                u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]
+            Hz = TensorList(Hz)
+            Hz_sq_acc = unpack_states(states, params, 'Hz_sq_acc', cls=TensorList)
+
+            beta = fs["beta"]
+            if beta is None:
+                Hz_sq_acc.addcmul_(Hz, Hz)
+
+            else:
+                Hz_sq_acc.mul_(beta).addcmul_(Hz, Hz, value=1-beta)
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        tensors = TensorList(objective.get_updates())
+        Hz_sq_acc = unpack_states(states, tensors, 'Hz_sq_acc', cls=TensorList)
+        num_Hzs = self.global_state["num_Hzs"]
+        fs = settings[0]

-                Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
-                                                        h=fd_h, normalize=True, retain_graph=j < n_samples-1)
+        # ---------------------------------- debias ---------------------------------- #
+        beta = fs["beta"]
+        beta_debias = fs["beta_debias"]

-                if D is None: D = Hvp
-                else: torch._foreach_add_(D, Hvp)
+        if beta_debias and beta is not None:
+            bias_correction = 1.0 - beta ** num_Hzs
+            Hz_sq_acc = Hz_sq_acc / bias_correction

-            assert D is not None
-            if n_samples > 1: torch._foreach_div_(D, n_samples)
+        else:
+            Hz_sq_acc = Hz_sq_acc / num_Hzs

-            D = TensorList(D)
+        # ---------------------------------- update ---------------------------------- #
+        damping = [s["damping"] for s in settings]

-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+        denom = (Hz_sq_acc / num_Hzs).sqrt_().add_(damping)

-        var.update, self.global_state['i'] = esgd_(
-            tensors_=TensorList(update),
-            D=TensorList(D) if D is not None else None,
-            D_sq_acc_=D_sq_acc,
-            damping=damping,
-            update_freq=update_freq,
-            step=step,
-            i=i,
-        )
-        return var
+        objective.updates = tensors.div_(denom)
+        return objective
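
The new `apply_states` path above boils down to dividing the update by a running estimate of the absolute Hessian diagonal built from Hutchinson probes. Here is a minimal standalone sketch of that preconditioner outside the torchzero module machinery (plain PyTorch; the toy quadratic and names like `hz_sq_acc` are illustrative only, not part of the package):

```python
import torch

def esgd_precondition(update, hz_sq_acc, num_hzs, damping=1e-4):
    # hz_sq_acc holds the sum of (H @ z)**2 over num_hzs random probes z,
    # so sqrt(hz_sq_acc / num_hzs) estimates the row norms of H
    # (for a diagonal H this is exactly |diag(H)|).
    denom = (hz_sq_acc / num_hzs).sqrt() + damping
    return update / denom

# toy quadratic f(x) = 0.5 * x.T @ H @ x, so H @ z is exact
H = torch.diag(torch.tensor([1.0, 10.0, 100.0]))
x = torch.ones(3)
grad = H @ x

hz_sq_acc = torch.zeros(3)
num_hzs = 16
for _ in range(num_hzs):
    z = torch.randn(3)      # Hutchinson probe vector
    hz = H @ z              # Hessian-vector product
    hz_sq_acc += hz * hz

print(esgd_precondition(grad, hz_sq_acc, num_hzs))  # all coordinates end up on a similar scale
```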
torchzero/modules/adaptive/ggt.py (new file)
@@ -0,0 +1,186 @@
+from collections import deque
+from typing import Literal, Any
+import warnings
+
+import torch
+from ...core import Chainable, TensorTransform
+from ...linalg import torch_linalg, regularize_eigh
+from .lre_optimizers import LREOptimizerBase
+
+def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol):
+    """returns U ``(ndim, rank)``, L ``(rank, )``"""
+    if isinstance(history, torch.Tensor):
+        M = history
+    else:
+        M = torch.stack(tuple(history), dim=1)# / len(history)
+
+    MtM = M.T @ M
+    if damping != 0:
+        MtM.add_(torch.eye(MtM.size(0), device=MtM.device, dtype=MtM.dtype).mul_(damping))
+
+    try:
+        L, Q = torch_linalg.eigh(MtM, retry_float64=True)
+
+        # damping is already added to MTM, rdamping is added afterwards
+        L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eig_tol, damping=0, rdamping=0)
+
+        if L is None or Q is None: # this means there are no finite eigenvalues
+            return None, None
+
+        U = (M @ Q) * L.rsqrt()
+
+        # this damping is added after computing U, this is why I didn't use one in linalg.regularize_eig
+        # that's because we damp singular values this way
+        if rdamping != 0:
+            L.add_(rdamping * L[-1]) # L is sorted in ascending order
+
+        return L, U
+
+    except torch.linalg.LinAlgError:
+        return None, None
+
+
+class GGT(TensorTransform):
+    """
+    GGT method from https://arxiv.org/pdf/1806.02958
+
+    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
+    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't neeed V.
+
+    This is equivalent to full-matrix Adagrad on recent gradients.
+
+    Args:
+        history_size (int, optional): number of past gradients to store. Defaults to 10.
+        beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
+        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
+        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
+        truncate (int, optional): number of larges eigenvalues to keep. None to disable. Defaults to None.
+        damping (float, optional): damping value. Defaults to 1e-4.
+        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
+        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
+        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
+
+    ## Examples:
+
+    Limited-memory Adagrad
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(),
+        tz.m.LR(0.1)
+    )
+    ```
+    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(0.01)
+    )
+    ```
+
+    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.ClipNormByEMA(max_ema_growth=1.2),
+        tz.m.LR(0.01)
+    )
+    ```
+    Reference:
+        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
+    """
+
+    def __init__(
+        self,
+        history_size: int = 100,
+        update_freq: int = 1,
+        eig_tol: float = 1e-7,
+        truncate: int | None = None,
+        damping: float = 1e-4,
+        rdamping: float = 0,
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        concat_params: bool = True,
+
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults['concat_params']
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
+        update_freq = setting['update_freq']
+
+        if 'history' not in state: state['history'] = deque(maxlen=history_size)
+        history = state['history']
+
+        t = tensor.clone().view(-1)
+        history.append(t)
+
+        step = state.get('step', 0)
+        state['step'] = step + 1
+
+        if step % update_freq == 0 :
+
+            # compute new factors
+            L = state.get("L", None)
+            U = state.get("U", None)
+
+            L_new, U_new = ggt_update(
+                history,
+                damping=setting["damping"],
+                rdamping=setting["rdamping"],
+                truncate=setting["truncate"],
+                eig_tol=setting["eig_tol"],
+            )
+
+            # reproject eigenbasis optimizer
+            eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+            if eigenbasis_optimizer is not None:
+                if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
+                    eigenbasis_state = state["eigenbasis_state"]
+                    eigenbasis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=eigenbasis_state)
+
+
+            # store new factors
+            if L_new is not None: state["L"] = L_new
+            if U_new is not None: state["U"] = U_new
+
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        g = tensor.view(-1)
+        U = state.get('U', None)
+
+        if U is None:
+            # fallback to element-wise preconditioning
+            history = torch.stack(tuple(state["history"]), 0)
+            g /= history.square().mean(0).sqrt().add(1e-8)
+            return g.view_as(tensor)
+
+        L = state['L']
+
+        # step with eigenbasis optimizer
+        eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+        if eigenbasis_optimizer is not None:
+
+            if "eigenbasis_state" not in state: state["eigenbasis_state"] = {}
+            eigenbasis_state = state["eigenbasis_state"]
+
+            update = eigenbasis_optimizer.step(g, L=L, Q=U, state=eigenbasis_state)
+            return update.view_as(tensor)
+
+        # or just whiten
+        z = U.T @ g
+        update = (U * L.rsqrt()) @ z
+        return update.view_as(tensor)
+
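
As a sanity check on the docstring's claim that an eigendecomposition of MᵀM reproduces the SVD-based update U S⁻¹ Uᵀ g, here is a small self-contained sketch in plain PyTorch (random M and g; no torchzero API involved, variable names are illustrative):

```python
import torch

torch.manual_seed(0)
ndim, rank = 50, 10
M = torch.randn(ndim, rank)   # columns play the role of recent gradients
g = torch.randn(ndim)

# route 1: thin SVD of M, update = U S^-1 Uᵀ g
U_svd, S, _ = torch.linalg.svd(M, full_matrices=False)
update_svd = U_svd @ ((U_svd.T @ g) / S)

# route 2: eigendecomposition of the small rank-by-rank matrix MᵀM,
# as in ggt_update: eigenvalues L are S**2, and U = (M @ Q) / sqrt(L)
L, Q = torch.linalg.eigh(M.T @ M)
U = (M @ Q) * L.rsqrt()
update_eig = (U * L.rsqrt()) @ (U.T @ g)

print(torch.allclose(update_svd, update_eig, atol=1e-4))  # True
```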
torchzero/modules/adaptive/lion.py
@@ -1,21 +1,17 @@
+from typing import Any
 import torch

-from ...core import Module, Target, Transform
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


-def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
-    """
-    Lion update rule.
-
-    Returns new tensors.
-    """
+def lion_(tensors: TensorList | Any, exp_avg_: TensorList | Any, beta1, beta2,):
     update = exp_avg_.lerp(tensors, 1-beta1).sign_()
     exp_avg_.lerp_(tensors, 1-beta2)
     return update


-class Lion(Transform):
+class Lion(TensorTransform):
     """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

     Args:
@@ -25,11 +21,11 @@ class Lion(Transform):

     def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
         defaults = dict(beta1=beta1, beta2=beta2)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)

     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
         exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
-        return lion_(TensorList(tensors),exp_avg,beta1,beta2)
+        return lion_(TensorList(tensors), exp_avg, beta1, beta2)
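
For reference, the `lion_` helper changed above implements the standard Lion rule; on a single dense tensor it reduces to the following short sketch (plain `torch.lerp` in place of the TensorList methods; names are illustrative):

```python
import torch

def lion_single(grad, exp_avg, beta1=0.9, beta2=0.99):
    # direction: sign of an interpolation between the momentum buffer and the gradient
    update = torch.lerp(exp_avg, grad, 1 - beta1).sign()
    # the momentum buffer itself tracks the gradient with the slower beta2
    exp_avg.lerp_(grad, 1 - beta2)
    return update

grad = torch.tensor([0.3, -1.2, 0.0])
exp_avg = torch.zeros(3)
print(lion_single(grad, exp_avg))  # tensor([ 1., -1.,  0.])
print(exp_avg)                     # moved 1% of the way toward grad
```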