torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/absoap.py
@@ -1,12 +1,14 @@
  from operator import itemgetter
+ from typing import Literal

  import torch
- from typing import Literal
- from ...core import Chainable, Transform, apply
+
+ from ...core import Chainable, Transform
  from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+ from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR

  @torch.no_grad
- def update_soap_covariances_(
+ def update_absoap_covariances_(
  g1: torch.Tensor,
  g2: torch.Tensor,
  GGs_: list[torch.Tensor | None],
@@ -19,138 +21,36 @@ def update_soap_covariances_(
  if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
  else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

- @torch.no_grad
- def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
- """
- Projects the gradient to the eigenbases of the preconditioner.
- """
- for mat in Q:
- if mat is None: continue
- if len(mat) > 0:
- tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
- else:
- # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
- permute_order = list(range(1, len(tensors.shape))) + [0]
- tensors = tensors.permute(permute_order)
-
- return tensors

- @torch.no_grad
- def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
- """
- Projects the gradient back to the original space.
- """
- for mat in Q:
- if mat is None: continue
- if len(mat) > 0:
- tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
- else:
- permute_order = list(range(1, len(tensors.shape))) + [0]
- tensors = tensors.permute(permute_order)
-
- return tensors
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
- @torch.no_grad
- def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
- """
- Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
- """
- matrix = []
- float_data = False
- original_type = original_device = None
- for m in mat:
- if m is None: continue
- if len(m) == 0:
- matrix.append([])
- continue
- if m.dtype != torch.float:
- original_type = m.dtype
- original_device = m.device
- matrix.append(m.float())
- else:
- float_data = True
- matrix.append(m)
-
- final = []
- for m in matrix:
- if len(m) == 0:
- final.append([])
- continue
- try:
- _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
- except Exception:
- _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
- Q = Q.to(m.dtype)
- Q = torch.flip(Q, [1])
-
- if not float_data:
- Q = Q.to(original_device).type(original_type)
- final.append(Q)
- return final
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
- @torch.no_grad
- def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
- """
- Computes the eigenbases of the preconditioner using one round of power iteration
- followed by torch.linalg.qr decomposition.
- """
- matrix = []
- orth_matrix = []
- float_data = False
- original_type = original_device = None
- for m,o in zip(GG, Q_list):
- if m is None: continue
- assert o is not None
-
- if len(m) == 0:
- matrix.append([])
- orth_matrix.append([])
- continue
- if m.data.dtype != torch.float:
- original_type = m.data.dtype
- original_device = m.data.device
- matrix.append(m.data.float())
- orth_matrix.append(o.data.float())
- else:
- float_data = True
- matrix.append(m.data.float())
- orth_matrix.append(o.data.float())
-
- final = []
- for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
- if len(m)==0:
- final.append([])
- continue
- est_eig = torch.diag(o.T @ m @ o)
- sort_idx = torch.argsort(est_eig, descending=True)
- exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
- o = o[:,sort_idx]
- power_iter = m @ o
- Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
- if not float_data:
- Q = Q.to(original_device).type(original_type)
- final.append(Q)
-
- return final, exp_avg_sq
-
- Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
+ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
  class ABSOAP(Transform):
- """SOAP but with two extra letters included in its name in order to improve converence
-
- so what you can do is choose what goes into what ,and that is supposed to be good.
+ """SOAP but with some extra options for testing.
+
+ .. warning::
+ This module is just for testing my stupid ideas.
+
+ Args:
+ scale_by_s - whether to scale y by s
+ gg1 - 1st vector into GGᵀ
+ gg2 - 2nd vector into GGᵀ
+ ema1 - vector into 1st momentum
+ ema2 - 2 vectors into 2nd momentum
+ rel1 - if True, multiplies gg1 by params
+ rel2 - same but for gg2
+ norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
+
+ letters:
+ p - params
+ g - grad
+ s - param difference
+ y - grad difference
+ gy - g+y
+ sy - s+y
+ sn - s normalized
+ yn - y normalized
+ gys - g + y#g
+ sys - s + y#s

- new args
-
- scale by s whether to scale gradient differences by parameter differences
-
- y_to_ema2 whether to use gradient differences for exponential moving average too
-
- okay I changed these args into another ones
-
- BASICALLY THIS IS FOR MY EXPERIMENTS
  """
  def __init__(
  self,
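The helpers removed above now live in torchzero/modules/optimizers/soap.py and are re-imported at the top of the file. For reference, project rotates each dimension of a tensor into the eigenbasis of the corresponding covariance factor and project_back undoes the rotation; below is a minimal round-trip sketch with plain tensors (a hypothetical standalone re-implementation for illustration, ignoring the None/empty-factor handling of the real helpers):

    # Round-trip check for the SOAP eigenbasis projection: project then
    # project_back should recover the original tensor when every factor
    # is square and orthogonal.
    import torch

    def project(t: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
        # rotate each dimension of t into the eigenbasis of its covariance factor
        for mat in Q:
            t = torch.tensordot(t, mat, dims=([0], [0]))
        return t

    def project_back(t: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
        # inverse rotation: contract against the transposed factors
        for mat in Q:
            t = torch.tensordot(t, mat, dims=([0], [1]))
        return t

    G = torch.randn(8, 4)
    # one orthogonal eigenbasis per dimension of G
    Q = [torch.linalg.qr(torch.randn(n, n))[0] for n in G.shape]
    assert torch.allclose(project_back(project(G, Q), Q), G, atol=1e-5)

With orthogonal factors the two calls are exact inverses, which is what lets the transform accumulate Adam statistics in the rotated space and map the resulting update back to the original parameter space.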
@@ -166,8 +66,8 @@ class ABSOAP(Transform):
  alpha: float = 1,
  bias_correction: bool = True,
  scale_by_s: bool = True,
- first: Source='g',
- second: Source='g',
+ gg1: Source='g',
+ gg2: Source='g',
  ema1: Source='g',
  ema2: tuple[Source, Source] = ('g','g'),
  rel1: bool=False,
@@ -189,29 +89,27 @@ class ABSOAP(Transform):
  scale_by_s=scale_by_s,
  ema1=ema1,
  ema2=ema2,
- first=first,
- second=second,
+ first=gg1,
+ second=gg2,
  rel1=rel1, rel2=rel2,
  norm=norm,
  )
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  updates = []
  # update preconditioners
- for i,(p,t) in enumerate(zip(params, tensors)):
- state = self.state[p]
- settings = self.settings[p]
+ for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
  beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
- 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
- scale_by_s = settings['scale_by_s']
- ema1 = settings['ema1']
- ema2 = settings['ema2']
- first=settings['first']
- second=settings['second']
- rel1 = settings['rel1']; rel2 = settings['rel2']
- norm=settings['norm']
+ 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
+ scale_by_s = setting['scale_by_s']
+ ema1 = setting['ema1']
+ ema2 = setting['ema2']
+ first=setting['first']
+ second=setting['second']
+ rel1 = setting['rel1']; rel2 = setting['rel2']
+ norm=setting['norm']

  if merge_small:
  t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -219,8 +117,8 @@ class ABSOAP(Transform):
  if 'g_prev' not in state:
  state['p_prev'] = p.clone()
  state['g_prev'] = t.clone()
- updates.append(tensors[i].clip(-0.1,0.1))
- continue
+ # updates.append(tensors[i].clip(-0.1,0.1))
+ # continue

  p_prev = state['p_prev']
  g_prev = state['g_prev']
@@ -270,11 +168,10 @@ class ABSOAP(Transform):
  t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
  t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable

-
  # initialize state on 1st step
  if 'GG' not in state:
  state["exp_avg"] = torch.zeros_like(t)
- state["exp_avg_sq"] = torch.ones_like(t)
+ state["exp_avg_sq"] = torch.zeros_like(t)

  if not precondition_1d and t.ndim <= 1:
  state['GG'] = []
@@ -287,7 +184,7 @@ class ABSOAP(Transform):
  state['GG'] = None

  if state['GG'] is not None:
- update_soap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
+ update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
  state['Q'] = get_orthogonal_matrix(state['GG'])

  state['step'] = 0
@@ -334,7 +231,7 @@ class ABSOAP(Transform):
  if z1_projected is not None:
  update = project_back(update, state["Q"])

- if settings['bias_correction']:
+ if setting['bias_correction']:
  bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
  bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
  update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -349,8 +246,8 @@ class ABSOAP(Transform):

  # Update is done after the gradient step to avoid using current gradients in the projection.
  if state['GG'] is not None:
- update_soap_covariances_(t1, t2, state['GG'], shampoo_beta)
- if state['step'] % settings['precond_freq'] == 0:
+ update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
+ if state['step'] % setting['precond_freq'] == 0:
  state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])

  return updates
torchzero/modules/experimental/adadam.py
@@ -10,7 +10,7 @@ from ..functional import (
  ema_,
  sqrt_ema_sq_,
  )
- from ..lr.lr import lazy_lr
+ from ..step_size.lr import lazy_lr
  from ..momentum.experimental import sqrt_nag_ema_sq_
  from ..momentum.momentum import nag_

@@ -50,7 +50,13 @@ def adadam_(
  return None

  class Adadam(Module):
- """Adam with a diagonally preconditioned preconditioner."""
+ """Adam with a diagonally preconditioned preconditioner.
+
+ Verdict: I haven't tested this yet.
+
+ .. warning::
+ Experimental.
+ """
  def __init__(
  self,
  beta1: float = 0.9,
@@ -67,31 +73,32 @@ class Adadam(Module):
  self.getter = itemgetter('amsgrad','pow','debiased')

  @torch.no_grad
- def step(self, vars):
+ def step(self, var):
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+ params = var.params

- beta1,beta2,precond_beta,eps,alpha=self.get_settings('beta1','beta2','precond_beta','eps','alpha', params=vars.params, cls=NumberList)
- amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+ beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
+ amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])

  if amsgrad:
- exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', params=vars.params, cls=TensorList)
+ exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
  else:
- exp_avg, exp_avg_sq, exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', params=vars.params, cls=TensorList)
+ exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
  max_exp_avg_sq = None
  max_exp_avg_qu = None

  # if this is last module, update parameters in-place with slightly more efficient addcdiv_
- if vars.is_last:
- if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
- passed_params = TensorList(vars.params)
- vars.stop = True
- vars.skip_update = True
+ if var.is_last:
+ if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+ passed_params = TensorList(var.params)
+ var.stop = True
+ var.skip_update = True

  else:
  passed_params = None

- vars.update = adadam_(
- tensors=TensorList(vars.get_update()),
+ var.update = adadam_(
+ tensors=TensorList(var.get_update()),
  exp_avg_=exp_avg,
  exp_avg_sq_=exp_avg_sq,
  exp_avg_qu_=exp_avg_qu,
@@ -108,4 +115,4 @@ class Adadam(Module):
  params_=passed_params,
  )

- return vars
+ return var
torchzero/modules/experimental/adamY.py
@@ -10,7 +10,7 @@ from ..functional import (
  ema_,
  sqrt_ema_sq_,
  )
- from ..lr.lr import lazy_lr
+ from ..step_size.lr import lazy_lr
  from ..momentum.experimental import sqrt_nag_ema_sq_
  from ..momentum.momentum import nag_

@@ -64,14 +64,10 @@ def adamy_(
  class AdamY(Module):
  """Adam but uses scaled gradient differences for second momentum.

- Args:
- beta1 (float, optional): momentum. Defaults to 0.9.
- beta2 (float, optional): second momentum. Defaults to 0.999.
- eps (float, optional): epsilon. Defaults to 1e-8.
- alpha (float, optional): learning rate. Defaults to 1.
- amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
- pow (float, optional): power used in second momentum power and root. Defaults to 2.
- debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+ Verdict: I haven't tested this yet.
+
+ .. warning::
+ Experimental.
  """
  def __init__(
  self,
@@ -88,36 +84,36 @@ class AdamY(Module):
  self.getter = itemgetter('amsgrad','pow','debiased')

  @torch.no_grad
- def step(self, vars):
+ def step(self, var):
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1

- beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha', params=vars.params, cls=NumberList)
- amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+ beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
+ amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])

  if amsgrad:
- exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq', params=vars.params, cls=TensorList)
+ exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
  else:
- exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq', params=vars.params, cls=TensorList)
+ exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
  max_exp_avg_sq = None

  # if this is last module, update parameters in-place with slightly more efficient addcdiv_
- if vars.is_last:
- if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
- passed_params = TensorList(vars.params)
- vars.stop = True
- vars.skip_update = True
+ if var.is_last:
+ if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+ passed_params = TensorList(var.params)
+ var.stop = True
+ var.skip_update = True

  else:
  passed_params = None

- p_prev = self.get_state('p_prev', params=vars.params, cls=TensorList)
- g_prev = self.get_state('g_prev', params=vars.params, cls=TensorList)
+ p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+ g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)


- vars.update = adamy_(
- p=TensorList(vars.params),
+ var.update = adamy_(
+ p=TensorList(var.params),
  p_prev=p_prev,
- g=TensorList(vars.get_update()),
+ g=TensorList(var.get_update()),
  g_prev=g_prev,
  exp_avg_=exp_avg,
  exp_avg_sq_=exp_avg_sq,
@@ -132,4 +128,4 @@ class AdamY(Module):
  params_=passed_params,
  )

- return vars
+ return var
torchzero/modules/experimental/adam_lambertw.py
@@ -0,0 +1,149 @@
+ from operator import itemgetter
+ from functools import partial
+ import math
+ import torch
+
+ from ...core import Module, Target, Transform, apply_transform, Chainable
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..functional import (
+ debias, debiased_step_size,
+ ema_,
+ sqrt_ema_sq_,
+ )
+ from ..step_size.lr import lazy_lr
+ from ..momentum.experimental import sqrt_nag_ema_sq_
+ from ..momentum.momentum import nag_
+
+
+ def _lambertw_newton_raphson(x: TensorList, iterations=5):
+ # z = torch.zeros_like(x)
+ # mask_neg = x < 0
+ # mask_pos = ~mask_neg
+
+ # z[mask_pos] = torch.log(x[mask_pos] + 1.0)
+
+ # x_neg = x[mask_neg]
+ # z_neg = -1.0 + torch.sqrt(2.0 * (1.0 + math.e * x_neg))
+ # z[mask_neg] = z_neg
+
+ # x is always positive
+ z = (x+1).log_()
+ for _ in range(iterations):
+ exp_z = z.exp()
+ numerator = z * exp_z - x
+ denominator = exp_z * (z + 1.0) + 1e-8
+ delta = numerator / denominator
+ z -= delta
+ return z
+
+ # https://github.com/gmgeorg/torchlambertw/blob/main/torchlambertw/special.py
+ def _lambertw_winitzki(x: TensorList):
+ x_log1p = x.log1p()
+ return x_log1p * (1.0 - x_log1p.log1p() / (2.0 + x_log1p))
+
+
+ def adam_lambertw_(
+ tensors: TensorList,
+ exp_avg_: TensorList,
+ exp_avg_xpx_: TensorList,
+ alpha: float | NumberList,
+ beta1: float | NumberList,
+ beta2: float | NumberList,
+ eps: float | NumberList,
+ step: int,
+ pow: float = 2,
+ debiased: bool = True,
+ max_exp_avg_xpx_: TensorList | None = None,
+ iterations: int | None = 5,
+
+ # inner args
+ inner: Module | None = None,
+ params: list[torch.Tensor] | None = None,
+ grads: list[torch.Tensor] | None = None,
+ ):
+ """Returns new tensors."""
+ tensors_abs = tensors.abs().clip_(max=20)
+ tensors_xpx = tensors_abs.pow_(tensors_abs)
+ exp_avg_xpx_.lerp_(tensors_xpx, 1-beta2)
+
+ if max_exp_avg_xpx_ is not None:
+ max_exp_avg_xpx_.maximum_(exp_avg_xpx_)
+ exp_avg_xpx_ = max_exp_avg_xpx_
+
+ if inner is not None:
+ assert params is not None
+ tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+ exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+ if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+ if iterations is None or iterations < 1: exp_avg_xpx_ = _lambertw_winitzki(exp_avg_xpx_)
+ else: exp_avg_xpx_ = _lambertw_newton_raphson(exp_avg_xpx_, iterations)
+
+ return (exp_avg_.lazy_mul(alpha) / exp_avg_xpx_.add_(eps))
+
+ class AdamLambertW(Transform):
+ """Adam but uses abs x^x and LambertW instead of square and sqrt.
+ The gradient will be clipped to 20 because float32 which you have to use otherwise you're PC will explode.
+
+ Args:
+ beta1 (float, optional): momentum. Defaults to 0.9.
+ beta2 (float, optional): second momentum. Defaults to 0.999.
+ eps (float, optional): epsilon. Defaults to 1e-8.
+ alpha (float, optional): learning rate. Defaults to 1.
+ amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+ pow (float, optional): power used in second momentum power and root. Defaults to 2.
+ debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+ iterations (int, optional): 0 or None means Winitzki approximation otherwise number of newton raphson iterations.
+ """
+ def __init__(
+ self,
+ beta1: float = 0.9,
+ beta2: float = 0.999,
+ eps: float = 1e-8,
+ amsgrad: bool = False,
+ alpha: float = 1.,
+ pow: float = 2,
+ debiased: bool = True,
+ iterations: int | None = 5,
+ inner: Chainable | None = None
+ ):
+ defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased, iterations=iterations)
+ super().__init__(defaults, uses_grad=False)
+
+ if inner is not None: self.set_child('inner', inner)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+ beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+ amsgrad,pow,debiased,iterations = itemgetter('amsgrad','pow','debiased','iterations')(settings[0])
+
+ if amsgrad:
+ exp_avg, exp_avg_xpx, max_exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', 'max_exp_avg_xpx', cls=TensorList)
+ else:
+ exp_avg, exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', cls=TensorList)
+ max_exp_avg_xpx = None
+
+
+ return adam_lambertw_(
+ tensors=TensorList(tensors),
+ exp_avg_=exp_avg,
+ exp_avg_xpx_=exp_avg_xpx,
+ alpha=alpha,
+ beta1=beta1,
+ beta2=beta2,
+ eps=eps,
+ step=step,
+ pow=pow,
+ debiased=debiased,
+ max_exp_avg_xpx_=max_exp_avg_xpx,
+ iterations=iterations,
+
+ # inner args
+ inner=self.children.get("inner", None),
+ params=params,
+ grads=grads,
+
+ )
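The new AdamLambertW module replaces Adam's squared-gradient accumulator with an EMA of |g|^|g| and divides by the Lambert W function of that EMA, solved by a few Newton-Raphson iterations (or the Winitzki approximation when iterations is 0/None). A quick sanity check of the same Newton iteration on plain tensors (illustrative sketch, not part of the package):

    # W(x) solves z * exp(z) = x; the Newton update used above is
    # z <- z - (z*exp(z) - x) / (exp(z)*(z + 1)), started from log1p(x).
    import torch

    def lambertw(x: torch.Tensor, iterations: int = 5) -> torch.Tensor:
        z = torch.log1p(x)            # initial guess, valid for x >= 0
        for _ in range(iterations):
            ez = z.exp()
            z = z - (z * ez - x) / (ez * (z + 1.0) + 1e-8)
        return z

    x = torch.tensor([0.1, 1.0, 10.0, 100.0])
    w = lambertw(x)
    print(w * w.exp())                # ≈ x, so w ≈ W(x)

Since z = W(x) satisfies z*exp(z) = x, the product w * w.exp() should land back on the inputs after a handful of iterations.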
torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py}
@@ -2,35 +2,64 @@ from operator import itemgetter

  import torch

- from .line_search import LineSearch
+ from ..line_search import LineSearchBase


- class TrustRegion(LineSearch):
- """Basic first order trust region, re-evaluates closure with updated parameters and scales step size based on function value change"""
+ class AdaptiveStepSize(LineSearchBase):
+ """Basic first order step size adaptation method. Re-evaluates the function after stepping, if value decreased sufficiently,
+ step size is increased. If value increased, step size is decreased.
+
+ .. note::
+ This works well in some cases, but it is often prone to collapsing.
+ For a more robust alternative use :code:`tz.m.AdaptiveBacktracking`.
+
+ Args:
+ nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
+ nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
+ c (float, optional): descent condition. Defaults to 1e-4.
+ init (float, optional): initial step size. Defaults to 1.
+ backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
+ adaptive (bool, optional):
+ If enabled, when multiple consecutive steps have been successful or unsuccessful,
+ the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
+
+
+ Examples:
+ Adagrad with trust region:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adagrad(),
+ tz.m.TrustRegion()
+ )
+
+ """
  def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
  defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
  super().__init__(defaults)

  @torch.no_grad
- def search(self, update, vars):
+ def search(self, update, var):

- nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[vars.params[0]])
+ nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
  step_size = self.global_state.setdefault('step_size', init)
  previous_success = self.global_state.setdefault('previous_success', False)
  nplus_mul = self.global_state.setdefault('nplus_mul', 1)
  nminus_mul = self.global_state.setdefault('nminus_mul', 1)


- f_0 = self.evaluate_step_size(0, vars, backward=False)
+ f_0 = self.evaluate_step_size(0, var, backward=False)

  # directional derivative (0 if c = 0 because it is not needed)
  if c == 0: d = 0
- else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
+ else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))

  # test step size
  sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]

- f_1 = self.evaluate_step_size(step_size, vars, backward=False)
+ f_1 = self.evaluate_step_size(step_size, var, backward=False)

  proposed = step_size
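The renamed AdaptiveStepSize module boils down to one accept/grow or reject/shrink decision per step: a step is successful when f_1 <= f_0 + c * t * min(d, 0), where d is the directional derivative along the proposed step. A schematic of that rule in plain Python (names are hypothetical; the consecutive-success "adaptive" multipliers and the optional backtracking of rejected steps are omitted):

    def adapt_step_size(f0: float, f1: float, t: float, d: float,
                        nplus: float = 1.5, nminus: float = 0.75, c: float = 1e-4):
        """One decision of the rule: returns (step_accepted, new_step_size)."""
        if f1 <= f0 + c * t * min(d, 0.0):   # sufficient decrease along the step
            return True, t * nplus           # success: keep the new point, grow the step size
        return False, t * nminus             # failure: shrink the step size

Repeating this decision every iteration is what lets a fixed-direction module such as Adagrad pick up a roughly well-scaled step size without an explicit line search, at the cost of the collapse behaviour the docstring warns about.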