torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/adaptive_step_size.py (deleted)
@@ -1,90 +0,0 @@
- from operator import itemgetter
-
- import torch
-
- from ..line_search import LineSearchBase
-
-
- class AdaptiveStepSize(LineSearchBase):
-     """Basic first order step size adaptation method. Re-evaluates the function after stepping, if value decreased sufficiently,
-     step size is increased. If value increased, step size is decreased.
-
-     .. note::
-         This works well in some cases, but it is often prone to collapsing.
-         For a more robust alternative use :code:`tz.m.AdaptiveBacktracking`.
-
-     Args:
-         nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
-         nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
-         c (float, optional): descent condition. Defaults to 1e-4.
-         init (float, optional): initial step size. Defaults to 1.
-         backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
-         adaptive (bool, optional):
-             If enabled, when multiple consecutive steps have been successful or unsuccessful,
-             the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
-
-
-     Examples:
-         Adagrad with trust region:
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.Adagrad(),
-                 tz.m.TrustRegion()
-             )
-
-     """
-     def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
-         defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def search(self, update, var):
-
-         nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
-         step_size = self.global_state.setdefault('step_size', init)
-         previous_success = self.global_state.setdefault('previous_success', False)
-         nplus_mul = self.global_state.setdefault('nplus_mul', 1)
-         nminus_mul = self.global_state.setdefault('nminus_mul', 1)
-
-
-         f_0 = self.evaluate_step_size(0, var, backward=False)
-
-         # directional derivative (0 if c = 0 because it is not needed)
-         if c == 0: d = 0
-         else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
-
-         # test step size
-         sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
-
-         f_1 = self.evaluate_step_size(step_size, var, backward=False)
-
-         proposed = step_size
-
-         # very good step
-         if f_1 < sufficient_f:
-             self.global_state['step_size'] *= nplus * nplus_mul
-
-             # two very good steps in a row - increase nplus_mul
-             if adaptive:
-                 if previous_success: self.global_state['nplus_mul'] *= nplus
-                 else: self.global_state['nplus_mul'] = 1
-
-         # acceptable step step
-         #elif f_1 <= f_0: pass
-
-         # bad step
-         if f_1 >= f_0:
-             self.global_state['step_size'] *= nminus * nminus_mul
-
-             # two bad steps in a row - decrease nminus_mul
-             if adaptive:
-                 if previous_success: self.global_state['nminus_mul'] *= nminus
-                 else: self.global_state['nminus_mul'] = 1
-
-             if backtrack: proposed = 0
-             else: proposed *= nminus * nminus_mul
-
-         return proposed
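
Note: the removed AdaptiveStepSize boils down to a multiplicative rule: grow the step size by nplus when the trial point satisfies f(x + a*d) < f(x) + c*a*min(g.d, 0), shrink it by nminus (and optionally undo the step) when the value did not decrease. A minimal scalar sketch of that rule, independent of torchzero internals (adapt_step_size, f, x and d are placeholder names, not library APIs):

def adapt_step_size(f, x, d, step_size, g_dot_d,
                    nplus=1.5, nminus=0.75, c=1e-4, backtrack=True):
    """Return (next_step_size, accepted_step_size) for one trial step along d."""
    f_0 = f(x)
    f_1 = f(x + step_size * d)
    sufficient = f_0 + c * step_size * min(g_dot_d, 0.0)

    accepted = step_size
    next_step_size = step_size
    if f_1 < sufficient:            # sufficient decrease: grow the step size for next time
        next_step_size = step_size * nplus
    if f_1 >= f_0:                  # value did not decrease: shrink, optionally undo the step
        next_step_size = step_size * nminus
        accepted = 0.0 if backtrack else step_size * nminus
    return next_step_size, accepted

When adaptive=True, the deleted module additionally compounds nplus/nminus after consecutive successful or unsuccessful steps, which the sketch omits.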
torchzero/modules/experimental/adasoap.py (deleted)
@@ -1,177 +0,0 @@
- from operator import itemgetter
-
- import torch
-
- from ...core import Chainable, Transform
- from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
- from ..optimizers.soap import (
-     get_orthogonal_matrix,
-     get_orthogonal_matrix_QR,
-     project,
-     project_back,
- )
-
-
- @torch.no_grad
- def update_adasoap_covariances_(
-     grad: torch.Tensor,
-     GGs_: list[torch.Tensor | None],
-     GG_sqs: list[torch.Tensor | None],
-     beta: float | None,
-     precond_beta: float | None,
- ):
-     for i, (GG, GG_sq) in enumerate(zip(GGs_, GG_sqs)):
-         if GG is None: continue
-         assert GG_sq is not None
-
-         if precond_beta is None: GG_sq.addcmul_(GG, GG)
-         else: GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1-precond_beta)
-
-         axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
- class AdaSOAP(Transform):
-     """SOAP with diagonally preconditioned GG^Ts.
-
-     .. warning::
-         Experimental.
-
-     precond_beta - beta for GG^T squares
-
-     Verdict: It works, but it is about the same performance as Adam, but maybe more tuning potential?
-     """
-     def __init__(
-         self,
-         beta1: float = 0.95,
-         beta2: float = 0.95,
-         shampoo_beta: float | None = 0.95,
-         precond_beta: float | None = 0.95,
-         precond_freq: int = 10,
-         merge_small: bool = True,
-         max_dim: int = 2_000,
-         precondition_1d: bool = True,
-         eps: float = 1e-8,
-         decay: float | None = None,
-         alpha: float = 1,
-         unprojected_exp_avg: bool = True,
-         bias_correction: bool = True,
-     ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_beta=precond_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             unprojected_exp_avg=unprojected_exp_avg,
-             bias_correction=bias_correction,
-             alpha=alpha,
-         )
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         updates = []
-         # update preconditioners
-         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
-             precond_beta = setting['precond_beta']
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 state["exp_avg_sq"] = torch.zeros_like(t)
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
-                     state['GG_sq'] = []
-
-                 else:
-                     state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-                     state['GG_sq'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
-                     state['GG_sq'] = None
-
-                 if state['GG'] is not None:
-                     assert state['GG_sq'] is not None
-                     update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                     GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                     state['Q'] = get_orthogonal_matrix(GG_precond)
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # that can mess with other modules scaling
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             t_projected = None
-             if state['GG'] is not None:
-                 t_projected = project(t, state['Q'])
-
-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-             if unprojected_exp_avg or t_projected is None:
-                 exp_avg.lerp_(t, 1-beta1)
-             else:
-                 exp_avg.lerp_(t_projected, 1-beta1)
-
-             if t_projected is None:
-                 exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-             else:
-                 exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
-
-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if unprojected_exp_avg and t_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
-
-             exp_avg_sq_projected = exp_avg_sq
-
-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
-             if t_projected is not None:
-                 update = project_back(update, state["Q"])
-
-             if setting['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
-
-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-             updates.append(update)
-             state["step"] += 1
-
-             # Update is done after the gradient step to avoid using current gradients in the projection.
-             if state['GG'] is not None:
-                 update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                 if state['step'] % setting['precond_freq'] == 0:
-                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
-
-         return updates
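
Note: the only change relative to the regular SOAP covariance update in this removed module is the extra elementwise accumulator GG_sq and the division GG / (GG_sq + 1e-8) before the eigenbasis is computed. A self-contained sketch of that step for a single 2D gradient (plain torch; torch.linalg.eigh stands in for the library's get_orthogonal_matrix and adasoap_covariance_update is a placeholder name, so treat this as an approximation, not the actual implementation):

import torch

def adasoap_covariance_update(grad, GG, GG_sq, shampoo_beta=0.95, precond_beta=0.95):
    """One update of the left covariance GG = EMA[g g^T] plus its elementwise-square EMA."""
    outer = grad @ grad.T                       # tensordot over all axes except dim 0
    GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1 - precond_beta)   # uses the old GG
    GG.lerp_(outer, 1 - shampoo_beta)
    GG_precond = GG / (GG_sq + 1e-8)            # diagonal (elementwise) preconditioning
    Q = torch.linalg.eigh(GG_precond)[1]        # eigenbasis used to project the gradient
    return Q

# usage on a toy 4x3 "gradient"
g = torch.randn(4, 3)
GG = torch.zeros(4, 4)
GG_sq = torch.zeros(4, 4)
Q = adasoap_covariance_update(g, GG, GG_sq)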
torchzero/modules/experimental/cosine.py (deleted)
@@ -1,214 +0,0 @@
- """A bunch of useless modules that I hate and that didn't work"""
- import torch
-
- from ...core import Chainable, Transform, apply_transform
- from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
-
-
- class CosineStepSize(Transform):
-     """Adaptive step size based on cosine similarity
-
-     VERDICT: Useless. This is too unstable.
-
-     Args:
-         scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
-         init (float, optional): initial step size. Defaults to 1.
-         eps (float, optional): epsilon for division stability. Defaults to 1e-12.
-         target_cossim (float, optional): cosine similarity needs to be above this to increase step size. Defaults to 1e-8.
-         inner (Chainable | None, optional):
-             inner modules applied after calculating cosine similarity and before step size correction. Defaults to None.
-     """
-     def __init__(self, scale:float = 0.95, init:float=1, eps:float=1e-12, inner:Chainable | None = None):
-         defaults = dict(scale=scale, init=init, eps=eps)
-         super().__init__(defaults, uses_grad=False)
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         scale, init = unpack_dicts(settings, 'scale', 'init', cls=NumberList)
-         unpack_states(states, tensors, 'alpha', init=init, cls=NumberList) # initializes alpha to init
-         eps = settings[0]['eps']
-
-         tensors = as_tensorlist(tensors)
-         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
-
-         tensors_norm = tensors.global_vector_norm()
-         cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
-
-         if 'inner' in self.children:
-             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-         new_alpha = []
-         for s, sc in zip(states, scale):
-             s['alpha'] *= 1 + cos_sim * sc
-             new_alpha.append(s['alpha'])
-
-         tensors.mul_(new_alpha)
-         prev.copy_(tensors)
-
-         return tensors
-
-
-
- class CosineDebounce(Transform):
-     """Debouncing when cosine similarity is less than 0.
-
-     VERDICT: Useless. This doesn't help at all.
-
-     Args:
-         scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
-         eps (float, optional): epsilon for division stability. Defaults to 1e-12.
-         inner (Chainable | None, optional):
-             inner modules applied after calculating cosine similarity and before debouncing correction. Defaults to None.
-     """
-     def __init__(self, scale:float = 0.95, eps:float=1e-12, damping:float=0.95, inner:Chainable | None = None):
-         defaults = dict(scale=scale, eps=eps, damping=damping)
-         super().__init__(defaults, uses_grad=False)
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
-         eps = settings[0]['eps']
-
-         tensors = as_tensorlist(tensors)
-         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList).mul_(damping)
-
-         tensors_norm = tensors.global_vector_norm()
-         cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
-
-         if 'inner' in self.children:
-             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-         if cos_sim < -eps:
-             undo = prev.neg().mul_(-cos_sim * scale)
-             comb = prev.graft(tensors).add_(tensors).graft_(prev).mul_(-cos_sim*scale)
-             tensors = undo.add_(comb)
-
-         prev.copy_(tensors)
-         return tensors
-
-
-
- class CosineMomentum(Transform):
-     """Beta depends on cosine similarity. At cossim=1, beta is 0. At cossim=-1, beta is 2^power. This basically removes oscillations.
-
-     VERDICT: Useless. Worse than all other momentums.
-
-     Args:
-         scale (float, optional): cosine similarity multiplier. Defaults to 1.
-         nesterov (float, optional): whether to use nesterov momentum. Defaults to False.
-         power (float, optional): power for beta. Defaults to 1.
-         eps (float, optional): epsilon for division stability. Defaults to 1e-12.
-         inner (Chainable | None, optional):
-             inner modules applied after calculating cosine similarity and before updating exponential moving average. Defaults to None.
-     """
-     def __init__(self, scale:float = 1, nesterov: bool = False, power: float = 1, eps:float=1e-12, inner:Chainable | None = None):
-         defaults = dict(scale=scale, eps=eps, nesterov=nesterov, power=power)
-         super().__init__(defaults, uses_grad=False)
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         scale, power = unpack_dicts(settings, 'scale', 'power', cls=NumberList)
-         eps = settings[0]['eps']
-         nesterov = settings[0]['nesterov']
-         exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
-
-         tensors = as_tensorlist(tensors)
-
-         tensors_norm = tensors.global_vector_norm()
-         cos_sim = (tensors.dot(exp_avg) / (tensors_norm * exp_avg.global_vector_norm()).clip(min=eps)).item()
-
-         if 'inner' in self.children:
-             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-         beta = (1 - (cos_sim*scale)) ** power
-         if nesterov:
-             exp_avg.add_(tensors.mul(beta))
-             return tensors.add_(exp_avg)
-         else:
-             exp_avg.add_(tensors.mul_(beta))
-             return exp_avg.clone()
-
-
- class AdaptiveDifference(Transform):
-     """VERDICT: Useless. Doesn't help (sort of to be expected)."""
-     def __init__(self, inner:Chainable | None = None):
-         defaults = dict()
-         super().__init__(defaults, uses_grad=False)
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         tensors = as_tensorlist(tensors)
-         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
-
-         diff = tensors - prev.graft_(tensors)
-         prev.copy_(tensors)
-
-         if 'inner' in self.children:
-             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-         tensors.add_(diff.graft_(tensors))
-
-         return tensors
-
- class AdaptiveDifferenceEMA(Transform):
-     """VERDICT: better than non-EMA but still useless."""
-     def __init__(self, beta=0.99, inner:Chainable | None = None):
-         defaults = dict(beta=beta)
-         super().__init__(defaults, uses_grad=False)
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         tensors = as_tensorlist(tensors)
-         beta = unpack_dicts(settings, 'beta', cls=NumberList)
-         prev, diff_exp_avg = unpack_states(states, tensors, 'prev', 'diff_exp_avg', init=[tensors,torch.zeros_like], cls=TensorList)
-
-         diff = (tensors - prev.graft_(tensors)).graft_(tensors)
-         diff_exp_avg.lerp_(diff, 1-beta)
-         prev.copy_(tensors)
-
-         if 'inner' in self.children:
-             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-         tensors.add_(diff_exp_avg.graft(tensors))
-
-         return tensors
-
-
- class ScaledAdaptiveDifference(Transform):
-     """VERDICT: Useless and doesn't help."""
-     def __init__(self, scale=0.95, damping:float=0.99, inner:Chainable | None = None):
-         defaults = dict(scale=scale, damping=damping)
-         super().__init__(defaults, uses_grad=False)
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         tensors = as_tensorlist(tensors)
-         scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
-         prev_tensors, prev_update = unpack_states(states, tensors, 'prev', 'prev_update', init=[tensors,tensors], cls=TensorList)
-
-         cos_sim = (tensors.dot(prev_update) / (tensors.global_vector_norm() * prev_update.global_vector_norm()).clip(min=1e-10)).item()
-
-         if 'inner' in self.children:
-             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-         if cos_sim > 0:
-             tensors.add_(prev_tensors*(cos_sim*scale))
-
-         else:
-             undo = prev_tensors.neg().mul_(-cos_sim*scale)
-             comb = prev_tensors.graft(tensors).add_(tensors).graft_(prev_tensors).mul_(-cos_sim*scale)
-             tensors = undo.add_(comb).graft_((tensors-prev_tensors).mul_(damping))
-
-         diff = tensors - prev_tensors.graft_(tensors)
-         prev_tensors.copy_(tensors)
-         diff.graft_(tensors)
-         tensors.add_(diff)
-         prev_update.copy_(tensors)
-
-         return tensors
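
Note: every module in this removed file keys off the cosine similarity between the current update and a stored previous one; CosineStepSize, for example, rescales a step size by 1 + cos_sim * scale. A minimal flat-vector sketch of that rule (cosine_step_size is a placeholder name; plain torch instead of TensorList):

import torch

def cosine_step_size(update, prev, alpha, scale=0.95, eps=1e-12):
    """Grow alpha when consecutive updates agree, shrink it when they oppose each other."""
    cos_sim = torch.dot(update, prev) / (update.norm() * prev.norm()).clamp(min=eps)
    alpha = alpha * (1 + cos_sim.item() * scale)
    return alpha, update * alpha

# usage: two mostly aligned updates, so alpha grows above 1
u_prev = torch.randn(10)
u = u_prev + 0.1 * torch.randn(10)
alpha, scaled = cosine_step_size(u, u_prev, alpha=1.0)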
torchzero/modules/experimental/cubic_adam.py (deleted)
@@ -1,97 +0,0 @@
- import torch
-
- from ...core import Transform
- from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-
-
- def signed_cbrt(x: TensorList) -> TensorList:
-     return x.sign() * x.abs().pow(1/3)
-
- def cubic_adam_(
-     tensors: TensorList,
-     exp_avg_: TensorList,
-     exp_avg_sq_: TensorList,
-     exp_avg_cu_: TensorList,
-     alpha: float | NumberList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     beta3: float | NumberList,
-     eps: float | NumberList,
-     debiased: bool,
-     step: int,
- ):
-     exp_avg_.lerp_(tensors, 1-beta1)
-     exp_avg_sq_.lerp_(tensors**2, 1-beta2)
-     exp_avg_cu_.lerp_(tensors**3, 1-beta3)
-
-     if debiased:
-         m1 = exp_avg_ / (1 - beta1 ** step)
-         m2 = exp_avg_sq_ / (1 - beta2 ** step)
-         m3 = exp_avg_cu_ / (1 - beta3 ** step)
-     else:
-         m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
-
-     # adam minimizes ax^2 + bx
-     # we are going to minimize ax^3 + bx^2 + cx
-     A = signed_cbrt(m3)
-     B = m2.sqrt()
-     C = m1
-     discriminant = B.pow(2) - 4 * A * C
-
-     denom = 2 * A
-     root = discriminant.clamp(min=0).sqrt_()
-
-     x0 = (-B + root) / (denom + eps)
-     x1 = (-B - root) / (denom + eps)
-
-     f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
-     f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
-
-     x_star = x0.where(f0 < f1, x1)
-
-     adam = -C / (B + eps)
-     x_star = adam.where(discriminant < 0, x_star)
-
-     return x_star.mul_(-alpha)
-
- class CubicAdam(Transform):
-     """Adam which has 3rd momentum and minimizes a cubic polynomial.
-
-     VERDICT: can outperform Adam very slightly. Usually very similar performance.
-
-     .. warning::
-         Experimental.
-
-     """
-     def __init__(
-         self,
-         beta1: float = 0.9,
-         beta2: float = 0.99,
-         beta3: float = 0.99,
-         eps: float = 1e-8,
-         debiased:bool=True,
-         alpha: float = 1.,
-     ):
-         defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         beta1,beta2,beta3,eps,alpha=unpack_dicts(settings, 'beta1','beta2','beta3','eps','alpha', cls=NumberList)
-         exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
-
-         return cubic_adam_(
-             tensors=TensorList(tensors),
-             exp_avg_=exp_avg,
-             exp_avg_sq_=exp_avg_sq,
-             exp_avg_cu_=exp_avg_cu,
-             alpha=alpha,
-             beta1=beta1,
-             beta2=beta2,
-             beta3=beta3,
-             eps=eps,
-             debiased=settings[0]['debiased'],
-             step=step,
-         )
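
Note: cubic_adam_ fits the per-coordinate model (A/3)x^3 + (B/2)x^2 + Cx with A = cbrt(m3), B = sqrt(m2), C = m1, picks the real stationary point of Ax^2 + Bx + C = 0 with the lower model value, and falls back to the Adam step -C/(B + eps) when the discriminant B^2 - 4AC is negative; the result is then scaled by -alpha. The same logic for a single scalar coordinate (cubic_adam_scalar is a placeholder name, plain Python instead of TensorList):

import math

def cubic_adam_scalar(m1, m2, m3, eps=1e-8):
    """Minimize (A/3)x^3 + (B/2)x^2 + Cx for one coordinate, as in the removed cubic_adam_."""
    A = math.copysign(abs(m3) ** (1 / 3), m3)   # signed cube root of the third moment
    B = math.sqrt(m2)
    C = m1
    disc = B * B - 4 * A * C
    if disc < 0:                                # no real stationary point: plain Adam step
        return -C / (B + eps)
    root = math.sqrt(disc)
    candidates = [(-B + root) / (2 * A + eps), (-B - root) / (2 * A + eps)]

    def model(x):
        return (A / 3) * x**3 + (B / 2) * x**2 + C * x

    return min(candidates, key=model)           # stationary point with the lower model value

With m3 driven to zero the cubic term vanishes and the step reduces to the usual Adam ratio m1 / (sqrt(m2) + eps), which is why the removed module behaves much like Adam in practice.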