torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in the public registry.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/core/preconditioner.py
@@ -1,138 +0,0 @@
- from abc import ABC, abstractmethod
- from collections import ChainMap, defaultdict
- from collections.abc import Mapping, Sequence
- from typing import Any, overload, final
-
- import torch
-
- from .module import Module, Chainable, Vars
- from .transform import apply, Transform, Target
- from ..utils import TensorList, vec_to_tensors
-
- class Preconditioner(Transform):
-     """Abstract class for a preconditioner."""
-     def __init__(
-         self,
-         defaults: dict | None,
-         uses_grad: bool,
-         concat_params: bool = False,
-         update_freq: int = 1,
-         scale_first: bool = False,
-         inner: Chainable | None = None,
-         target: Target = "update",
-     ):
-         if defaults is None: defaults = {}
-         defaults.update(dict(__update_freq=update_freq, __concat_params=concat_params, __scale_first=scale_first))
-         super().__init__(defaults, uses_grad=uses_grad, target=target)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @abstractmethod
-     def update(self, tensors: list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
-         """updates the preconditioner with `tensors`, any internal state should be stored using `keys`"""
-
-     @abstractmethod
-     def apply(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> list[torch.Tensor]:
-         """applies preconditioner to `tensors`, any internal state should be stored using `keys`"""
-
-
-     def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-         step = self.global_state.get('__step', 0)
-         states = [self.state[p] for p in params]
-         settings = [self.settings[p] for p in params]
-         global_settings = settings[0]
-         update_freq = global_settings['__update_freq']
-
-         scale_first = global_settings['__scale_first']
-         scale_factor = 1
-         if scale_first and step == 0:
-             # initial step size guess from pytorch LBFGS
-             scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
-             scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
-
-         # update preconditioner
-         if step % update_freq == 0:
-             self.update(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-         # step with inner
-         if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-
-         # apply preconditioner
-         tensors = self.apply(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-         # scale initial step, when preconditioner might not have been applied
-         if scale_first and step == 0:
-             torch._foreach_mul_(tensors, scale_factor)
-
-         self.global_state['__step'] = step + 1
-         return tensors
-
-     def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-         step = self.global_state.get('__step', 0)
-         tensors_vec = torch.cat([t.ravel() for t in tensors])
-         params_vec = torch.cat([p.ravel() for p in params])
-         grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
-
-         states = [self.state[params[0]]]
-         settings = [self.settings[params[0]]]
-         global_settings = settings[0]
-         update_freq = global_settings['__update_freq']
-
-         scale_first = global_settings['__scale_first']
-         scale_factor = 1
-         if scale_first and step == 0:
-             # initial step size guess from pytorch LBFGS
-             scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
-             scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)
-
-         # update preconditioner
-         if step % update_freq == 0:
-             self.update(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)
-
-         # step with inner
-         if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-             tensors_vec = torch.cat([t.ravel() for t in tensors]) # have to recat
-
-         # apply preconditioner
-         tensors_vec = self.apply(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)[0]
-
-         # scale initial step, when preconditioner might not have been applied
-         if scale_first and step == 0:
-             tensors_vec *= scale_factor
-
-         tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-         self.global_state['__step'] = step + 1
-         return tensors
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         concat_params = self.settings[params[0]]['__concat_params']
-         if concat_params: return self._concat_transform(tensors, params, grads, vars)
-         return self._tensor_wise_transform(tensors, params, grads, vars)
-
- class TensorwisePreconditioner(Preconditioner, ABC):
-     @abstractmethod
-     def update_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]):
-         """update preconditioner with `tensor`"""
-
-     @abstractmethod
-     def apply_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
-         """apply preconditioner to `tensor`"""
-
-     @final
-     def update(self, tensors, params, grads, states, settings):
-         if grads is None: grads = [None]*len(tensors)
-         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-             self.update_tensor(t, p, g, state, setting)
-
-     @final
-     def apply(self, tensors, params, grads, states, settings):
-         preconditioned = []
-         if grads is None: grads = [None]*len(tensors)
-         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-             preconditioned.append(self.apply_tensor(t, p, g, state, setting))
-         return preconditioned
-
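Note on the removed file above: the abstract `Preconditioner` / `TensorwisePreconditioner` base classes are gone in 0.3.11. For readers of the 0.3.9 code, here is a minimal sketch of how the tensorwise API was meant to be subclassed; the class name `DiagonalRMSPreconditioner` and its `beta`/`eps` settings are illustrative only and are not part of torchzero.

    import torch

    class DiagonalRMSPreconditioner(TensorwisePreconditioner):
        """Toy diagonal preconditioner: divides updates by an EMA of their squared values."""
        def __init__(self, beta: float = 0.99, eps: float = 1e-8, update_freq: int = 1):
            super().__init__(defaults=dict(beta=beta, eps=eps), uses_grad=False, update_freq=update_freq)

        def update_tensor(self, tensor, param, grad, state, settings):
            # accumulate the second moment in the per-parameter `state` dict
            if 'exp_avg_sq' not in state:
                state['exp_avg_sq'] = torch.zeros_like(tensor)
            state['exp_avg_sq'].lerp_(tensor.square(), 1 - settings['beta'])

        def apply_tensor(self, tensor, param, grad, state, settings):
            # scale the update by the inverse root of the accumulated second moment
            return tensor / (state['exp_avg_sq'].sqrt() + settings['eps'])

The base class's `transform` would then handle `update_freq`, the optional `inner` chain, and the concatenated versus tensorwise dispatch shown above.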
torchzero/modules/experimental/algebraic_newton.py
@@ -1,145 +0,0 @@
- import warnings
- from functools import partial
- from typing import Literal
- from collections.abc import Callable
- import torch
- import torchalgebras as ta
-
- from ...core import Chainable, apply, Module
- from ...utils import vec_to_tensors, TensorList
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     hessian_mat,
-     jacobian_and_hessian_wrt,
- )
-
- class MaxItersReached(Exception): pass
- def tropical_lstsq(
-     H: torch.Tensor,
-     g: torch.Tensor,
-     solver,
-     maxiter,
-     tol,
-     algebra,
-     verbose,
- ):
-     """it can run on any algebra with add despite it saying tropical"""
-     algebra = ta.get_algebra(algebra)
-
-     x = torch.zeros_like(g, requires_grad=True)
-     best_x = x.detach().clone()
-     best_loss = float('inf')
-     opt = solver([x])
-
-     niter = 0
-     def closure(backward=True):
-         nonlocal niter, best_x, best_loss
-         if niter == maxiter: raise MaxItersReached
-         niter += 1
-
-         g_hat = algebra.mm(H, x)
-         loss = torch.nn.functional.mse_loss(g_hat, g)
-         if loss < best_loss:
-             best_x = x.detach().clone()
-             best_loss = loss.detach()
-
-         if backward:
-             opt.zero_grad()
-             loss.backward()
-         return loss
-
-     loss = None
-     prev_loss = float('inf')
-     for i in range(maxiter):
-         try:
-             loss = opt.step(closure)
-             if loss == 0: break
-             if tol is not None and prev_loss - loss < tol: break
-             prev_loss = loss
-         except MaxItersReached:
-             break
-
-     if verbose: print(f'{best_loss = } after {niter} iters')
-     return best_x.detach()
-
- def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemiring()):
-     if reg!=0:
-         I = ta.AlgebraicTensor(torch.eye(H.size(-1), dtype=H.dtype, device=H.device), algebra)
-         I = I * reg
-         H = algebra.add(H, I.data)
-     return H
-
-
- class AlgebraicNewton(Module):
-     """newton in other algebras, not that it works."""
-     def __init__(
-         self,
-         reg: float | None = None,
-         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-         vectorize: bool = True,
-         solver=lambda p: torch.optim.LBFGS(p, line_search_fn='strong_wolfe'),
-         maxiter=1000,
-         tol: float | None = 1e-10,
-         algebra: ta.Algebra | str = 'tropical max',
-         verbose: bool = False,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize)
-         super().__init__(defaults)
-
-         self.algebra = ta.get_algebra(algebra)
-         self.lstsq_args:dict = dict(solver=solver, maxiter=maxiter, tol=tol, algebra=algebra, verbose=verbose)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def step(self, vars):
-         params = TensorList(vars.params)
-         closure = vars.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         reg = settings['reg']
-         hessian_method = settings['hessian_method']
-         vectorize = settings['vectorize']
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         if hessian_method == 'autograd':
-             with torch.enable_grad():
-                 loss = vars.loss = vars.loss_approx = closure(False)
-                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                 g_list = [t[0] for t in g_list] # remove leading dim from loss
-                 vars.grad = g_list
-                 H = hessian_list_to_mat(H_list)
-
-         elif hessian_method in ('func', 'autograd.functional'):
-             strat = 'forward-mode' if vectorize else 'reverse-mode'
-             with torch.enable_grad():
-                 g_list = vars.get_grad(retain_graph=True)
-                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                                               method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-         else:
-             raise ValueError(hessian_method)
-
-         # -------------------------------- inner step -------------------------------- #
-         if 'inner' in self.children:
-             g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
-         g = torch.cat([t.view(-1) for t in g_list])
-
-         # ------------------------------- regulazition ------------------------------- #
-         if reg is not None: H = tikhonov(H, reg)
-
-         # ----------------------------------- solve ---------------------------------- #
-         tropical_update = tropical_lstsq(H, g, **self.lstsq_args)
-         # what now? w - u is not defined, it is defined for max version if u < w
-         # w = params.to_vec()
-         # w_hat = self.algebra.sub(w, tropical_update)
-         # update = w_hat - w
-         # no
-         # it makes sense to solve tropical system and sub normally
-         # the only thing is that tropical system can have no solutions
-
-         vars.update = vec_to_tensors(tropical_update, params)
-         return vars
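For context on the deleted solver above: `tropical_lstsq` fits `algebra.mm(H, x) ≈ g` by gradient descent, where under the default 'tropical max' algebra the matrix product replaces summation with max and multiplication with addition. A standalone PyTorch illustration of that max-plus matrix-vector product (written directly, not via the torchalgebras API) is:

    import torch

    def maxplus_mv(H: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # max-plus product: out[i] = max_j (H[i, j] + x[j]); sum -> max, * -> +
        return (H + x.unsqueeze(0)).amax(dim=1)

    H = torch.tensor([[0., 2.], [1., -1.]])
    x = torch.tensor([3., 5.])
    print(maxplus_mv(H, x))  # tensor([7., 4.])

Because `amax` is subdifferentiable, a least-squares fit like the loop above can still backpropagate through the product, which is what made the gradient-based solve in `tropical_lstsq` possible.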
torchzero/modules/experimental/soapy.py
@@ -1,290 +0,0 @@
- from operator import itemgetter
-
- import torch
-
- from ...core import Chainable, Transform, apply
- from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-
- @torch.no_grad
- def update_soap_covariances_(
-     grad: torch.Tensor,
-     GGs_: list[torch.Tensor | None],
-     beta: float | None,
- ):
-     for i, GG in enumerate(GGs_):
-         if GG is None: continue
-
-         axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
- @torch.no_grad
- def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-     """
-     Projects the gradient to the eigenbases of the preconditioner.
-     """
-     for mat in Q:
-         if mat is None: continue
-         if len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-         else:
-             # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
-
-     return tensors
-
- @torch.no_grad
- def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-     """
-     Projects the gradient back to the original space.
-     """
-     for mat in Q:
-         if mat is None: continue
-         if len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-         else:
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
-
-     return tensors
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
- @torch.no_grad
- def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-     """
-     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-     """
-     matrix = []
-     float_data = False
-     original_type = original_device = None
-     for m in mat:
-         if m is None: continue
-         if len(m) == 0:
-             matrix.append([])
-             continue
-         if m.dtype != torch.float:
-             original_type = m.dtype
-             original_device = m.device
-             matrix.append(m.float())
-         else:
-             float_data = True
-             matrix.append(m)
-
-     final = []
-     for m in matrix:
-         if len(m) == 0:
-             final.append([])
-             continue
-         try:
-             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-         except Exception:
-             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-             Q = Q.to(m.dtype)
-         Q = torch.flip(Q, [1])
-
-         if not float_data:
-             Q = Q.to(original_device).type(original_type)
-         final.append(Q)
-     return final
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
- @torch.no_grad
- def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-     """
-     Computes the eigenbases of the preconditioner using one round of power iteration
-     followed by torch.linalg.qr decomposition.
-     """
-     matrix = []
-     orth_matrix = []
-     float_data = False
-     original_type = original_device = None
-     for m,o in zip(GG, Q_list):
-         if m is None: continue
-         assert o is not None
-
-         if len(m) == 0:
-             matrix.append([])
-             orth_matrix.append([])
-             continue
-         if m.data.dtype != torch.float:
-             original_type = m.data.dtype
-             original_device = m.data.device
-             matrix.append(m.data.float())
-             orth_matrix.append(o.data.float())
-         else:
-             float_data = True
-             matrix.append(m.data.float())
-             orth_matrix.append(o.data.float())
-
-     final = []
-     for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-         if len(m)==0:
-             final.append([])
-             continue
-         est_eig = torch.diag(o.T @ m @ o)
-         sort_idx = torch.argsort(est_eig, descending=True)
-         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-         o = o[:,sort_idx]
-         power_iter = m @ o
-         Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-         if not float_data:
-             Q = Q.to(original_device).type(original_type)
-         final.append(Q)
-
-     return final, exp_avg_sq
-
- class SOAPY(Transform):
-     """SOAP but uses scaled gradient differences
-
-     new args
-
-     scale by s whether to scale gradient differences by parameter differences
-
-     y_to_ema2 whether to use gradient differences for exponential moving average too
-     """
-     def __init__(
-         self,
-         beta1: float = 0.95,
-         beta2: float = 0.95,
-         shampoo_beta: float | None = 0.95,
-         precond_freq: int = 10,
-         merge_small: bool = True,
-         max_dim: int = 2_000,
-         precondition_1d: bool = True,
-         eps: float = 1e-8,
-         decay: float | None = None,
-         alpha: float = 1,
-         bias_correction: bool = True,
-         scale_by_s: bool = True,
-         y_to_ema2: bool = False,
-     ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             bias_correction=bias_correction,
-             alpha=alpha,
-             scale_by_s=scale_by_s,
-             y_to_ema2=y_to_ema2,
-         )
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         updates = []
-         # update preconditioners
-         for i,(p,t) in enumerate(zip(params, tensors)):
-             state = self.state[p]
-             settings = self.settings[p]
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
-             scale_by_s = settings['scale_by_s']
-             y_to_ema2 = settings['y_to_ema2']
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             if 'g_prev' not in state:
-                 state['p_prev'] = p.clone()
-                 state['g_prev'] = t.clone()
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue
-
-             p_prev = state['p_prev']
-             g_prev = state['g_prev']
-             s = p - p_prev
-             y = t - g_prev
-             if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable
-
-             state['p_prev'].copy_(p)
-             state['g_prev'].copy_(t)
-
-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
-                 else: state["exp_avg_sq"] = torch.zeros_like(t)
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
-
-                 else:
-                     state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
-
-                 if state['GG'] is not None:
-                     update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
-                     state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             z_projected = None
-             if state['GG'] is not None:
-                 if y_to_ema2: z_projected = project(y, state['Q'])
-                 else: z_projected = project(t, state['Q'])
-
-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-             exp_avg.lerp_(t, 1-beta1)
-
-             if z_projected is None:
-                 if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
-                 else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-             else:
-                 exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)
-
-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if z_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
-
-             exp_avg_sq_projected = exp_avg_sq
-
-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
-             if z_projected is not None:
-                 update = project_back(update, state["Q"])
-
-             if settings['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
-
-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-             updates.append(update)
-             state["step"] += 1
-
-             # Update is done after the gradient step to avoid using current gradients in the projection.
-             if state['GG'] is not None:
-                 update_soap_covariances_(y, state['GG'], shampoo_beta)
-                 if state['step'] % settings['precond_freq'] == 0:
-                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-         return updates
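For reference on the deleted SOAPY transform: `project` and `project_back` above rotate a tensor into and out of the eigenbases stored in `state['Q']` via repeated `tensordot` contractions, so with orthogonal Q factors the round trip is the identity. A small sketch, reusing the two helpers from the deleted file with random orthogonal factors standing in for the eigenbases:

    import torch

    G = torch.randn(4, 3)
    Q0, _ = torch.linalg.qr(torch.randn(4, 4))   # orthogonal basis for dim 0
    Q1, _ = torch.linalg.qr(torch.randn(3, 3))   # orthogonal basis for dim 1

    G_proj = project(G, [Q0, Q1])                # equivalent to Q0.T @ G @ Q1
    G_back = project_back(G_proj, [Q0, Q1])      # rotates back to the original space

    assert torch.allclose(G, G_back, atol=1e-5)

SOAPY accumulates the Adam second-moment statistics in that rotated space and only projects the final update back, which is the SOAP recipe applied to the scaled gradient differences `y` described in the class docstring.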