torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/cosine.py (new file)
@@ -0,0 +1,214 @@
+ """A bunch of useless modules that I hate and that didn't work"""
+ import torch
+
+ from ...core import Chainable, Transform, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+
+
+ class CosineStepSize(Transform):
+     """Adaptive step size based on cosine similarity
+
+     VERDICT: Useless. This is too unstable.
+
+     Args:
+         scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+         init (float, optional): initial step size. Defaults to 1.
+         eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+         target_cossim (float, optional): cosine similarity needs to be above this to increase step size. Defaults to 1e-8.
+         inner (Chainable | None, optional):
+             inner modules applied after calculating cosine similarity and before step size correction. Defaults to None.
+     """
+     def __init__(self, scale:float = 0.95, init:float=1, eps:float=1e-12, inner:Chainable | None = None):
+         defaults = dict(scale=scale, init=init, eps=eps)
+         super().__init__(defaults, uses_grad=False)
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         scale, init = unpack_dicts(settings, 'scale', 'init', cls=NumberList)
+         unpack_states(states, tensors, 'alpha', init=init, cls=NumberList) # initializes alpha to init
+         eps = settings[0]['eps']
+
+         tensors = as_tensorlist(tensors)
+         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+         tensors_norm = tensors.global_vector_norm()
+         cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+         if 'inner' in self.children:
+             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+         new_alpha = []
+         for s, sc in zip(states, scale):
+             s['alpha'] *= 1 + cos_sim * sc
+             new_alpha.append(s['alpha'])
+
+         tensors.mul_(new_alpha)
+         prev.copy_(tensors)
+
+         return tensors
+
+
+
+ class CosineDebounce(Transform):
+     """Debouncing when cosine similarity is less than 0.
+
+     VERDICT: Useless. This doesn't help at all.
+
+     Args:
+         scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+         eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+         inner (Chainable | None, optional):
+             inner modules applied after calculating cosine similarity and before debouncing correction. Defaults to None.
+     """
+     def __init__(self, scale:float = 0.95, eps:float=1e-12, damping:float=0.95, inner:Chainable | None = None):
+         defaults = dict(scale=scale, eps=eps, damping=damping)
+         super().__init__(defaults, uses_grad=False)
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+         eps = settings[0]['eps']
+
+         tensors = as_tensorlist(tensors)
+         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList).mul_(damping)
+
+         tensors_norm = tensors.global_vector_norm()
+         cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+         if 'inner' in self.children:
+             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+         if cos_sim < -eps:
+             undo = prev.neg().mul_(-cos_sim * scale)
+             comb = prev.graft(tensors).add_(tensors).graft_(prev).mul_(-cos_sim*scale)
+             tensors = undo.add_(comb)
+
+         prev.copy_(tensors)
+         return tensors
+
+
+
+ class CosineMomentum(Transform):
+     """Beta depends on cosine similarity. At cossim=1, beta is 0. At cossim=-1, beta is 2^power. This basically removes oscillations.
+
+     VERDICT: Useless. Worse than all other momentums.
+
+     Args:
+         scale (float, optional): cosine similarity multiplier. Defaults to 1.
+         nesterov (float, optional): whether to use nesterov momentum. Defaults to False.
+         power (float, optional): power for beta. Defaults to 1.
+         eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+         inner (Chainable | None, optional):
+             inner modules applied after calculating cosine similarity and before updating exponential moving average. Defaults to None.
+     """
+     def __init__(self, scale:float = 1, nesterov: bool = False, power: float = 1, eps:float=1e-12, inner:Chainable | None = None):
+         defaults = dict(scale=scale, eps=eps, nesterov=nesterov, power=power)
+         super().__init__(defaults, uses_grad=False)
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         scale, power = unpack_dicts(settings, 'scale', 'power', cls=NumberList)
+         eps = settings[0]['eps']
+         nesterov = settings[0]['nesterov']
+         exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
+
+         tensors = as_tensorlist(tensors)
+
+         tensors_norm = tensors.global_vector_norm()
+         cos_sim = (tensors.dot(exp_avg) / (tensors_norm * exp_avg.global_vector_norm()).clip(min=eps)).item()
+
+         if 'inner' in self.children:
+             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+         beta = (1 - (cos_sim*scale)) ** power
+         if nesterov:
+             exp_avg.add_(tensors.mul(beta))
+             return tensors.add_(exp_avg)
+         else:
+             exp_avg.add_(tensors.mul_(beta))
+             return exp_avg.clone()
+
+
+ class AdaptiveDifference(Transform):
+     """VERDICT: Useless. Doesn't help (sort of to be expected)."""
+     def __init__(self, inner:Chainable | None = None):
+         defaults = dict()
+         super().__init__(defaults, uses_grad=False)
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = as_tensorlist(tensors)
+         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+         diff = tensors - prev.graft_(tensors)
+         prev.copy_(tensors)
+
+         if 'inner' in self.children:
+             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+         tensors.add_(diff.graft_(tensors))
+
+         return tensors
+
+ class AdaptiveDifferenceEMA(Transform):
+     """VERDICT: better than non-EMA but still useless."""
+     def __init__(self, beta=0.99, inner:Chainable | None = None):
+         defaults = dict(beta=beta)
+         super().__init__(defaults, uses_grad=False)
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = as_tensorlist(tensors)
+         beta = unpack_dicts(settings, 'beta', cls=NumberList)
+         prev, diff_exp_avg = unpack_states(states, tensors, 'prev', 'diff_exp_avg', init=[tensors,torch.zeros_like], cls=TensorList)
+
+         diff = (tensors - prev.graft_(tensors)).graft_(tensors)
+         diff_exp_avg.lerp_(diff, 1-beta)
+         prev.copy_(tensors)
+
+         if 'inner' in self.children:
+             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+         tensors.add_(diff_exp_avg.graft(tensors))
+
+         return tensors
+
+
+ class ScaledAdaptiveDifference(Transform):
+     """VERDICT: Useless and doesn't help."""
+     def __init__(self, scale=0.95, damping:float=0.99, inner:Chainable | None = None):
+         defaults = dict(scale=scale, damping=damping)
+         super().__init__(defaults, uses_grad=False)
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = as_tensorlist(tensors)
+         scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+         prev_tensors, prev_update = unpack_states(states, tensors, 'prev', 'prev_update', init=[tensors,tensors], cls=TensorList)
+
+         cos_sim = (tensors.dot(prev_update) / (tensors.global_vector_norm() * prev_update.global_vector_norm()).clip(min=1e-10)).item()
+
+         if 'inner' in self.children:
+             tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+         if cos_sim > 0:
+             tensors.add_(prev_tensors*(cos_sim*scale))
+
+         else:
+             undo = prev_tensors.neg().mul_(-cos_sim*scale)
+             comb = prev_tensors.graft(tensors).add_(tensors).graft_(prev_tensors).mul_(-cos_sim*scale)
+             tensors = undo.add_(comb).graft_((tensors-prev_tensors).mul_(damping))
+
+         diff = tensors - prev_tensors.graft_(tensors)
+         prev_tensors.copy_(tensors)
+         diff.graft_(tensors)
+         tensors.add_(diff)
+         prev_update.copy_(tensors)
+
+         return tensors
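
The rule behind CosineStepSize and CosineMomentum is simple to reproduce outside of torchzero's Transform machinery: compare the current update with the previous one and scale the step accordingly. A minimal standalone sketch on flat PyTorch tensors; the helper name and the toy loop are illustrative and not part of the package:

import torch

def cosine_adaptive_step(update, prev_update, alpha, scale=0.95, eps=1e-12):
    # Grow alpha when the update agrees with the previous one (cos_sim > 0),
    # shrink it when they point in opposite directions (cos_sim < 0).
    cos_sim = torch.dot(update, prev_update) / (update.norm() * prev_update.norm()).clamp(min=eps)
    alpha = alpha * (1 + cos_sim.item() * scale)
    return update * alpha, alpha

# toy usage on f(x) = sum(x^2)
x = torch.tensor([3.0, -2.0], requires_grad=True)
alpha, prev = 1.0, torch.zeros(2)
for _ in range(20):
    loss = (x ** 2).sum()
    loss.backward()
    with torch.no_grad():
        step, alpha = cosine_adaptive_step(x.grad.clone(), prev, alpha)
        prev = step.clone()
        x -= 0.1 * step
        x.grad.zero_()

CosineMomentum applies the same similarity to the momentum coefficient instead, via beta = (1 - cos_sim * scale) ** power.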
torchzero/modules/experimental/cubic_adam.py (new file)
@@ -0,0 +1,97 @@
+ import torch
+
+ from ...core import Transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+
+ def signed_cbrt(x: TensorList) -> TensorList:
+     return x.sign() * x.abs().pow(1/3)
+
+ def cubic_adam_(
+     tensors: TensorList,
+     exp_avg_: TensorList,
+     exp_avg_sq_: TensorList,
+     exp_avg_cu_: TensorList,
+     alpha: float | NumberList,
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     beta3: float | NumberList,
+     eps: float | NumberList,
+     debiased: bool,
+     step: int,
+ ):
+     exp_avg_.lerp_(tensors, 1-beta1)
+     exp_avg_sq_.lerp_(tensors**2, 1-beta2)
+     exp_avg_cu_.lerp_(tensors**3, 1-beta3)
+
+     if debiased:
+         m1 = exp_avg_ / (1 - beta1 ** step)
+         m2 = exp_avg_sq_ / (1 - beta2 ** step)
+         m3 = exp_avg_cu_ / (1 - beta3 ** step)
+     else:
+         m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
+
+     # adam minimizes ax^2 + bx
+     # we are going to minimize ax^3 + bx^2 + cx
+     A = signed_cbrt(m3)
+     B = m2.sqrt()
+     C = m1
+     discriminant = B.pow(2) - 4 * A * C
+
+     denom = 2 * A
+     root = discriminant.clamp(min=0).sqrt_()
+
+     x0 = (-B + root) / (denom + eps)
+     x1 = (-B - root) / (denom + eps)
+
+     f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
+     f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
+
+     x_star = x0.where(f0 < f1, x1)
+
+     adam = -C / (B + eps)
+     x_star = adam.where(discriminant < 0, x_star)
+
+     return x_star.mul_(-alpha)
+
+ class CubicAdam(Transform):
+     """Adam which has 3rd momentum and minimizes a cubic polynomial.
+
+     VERDICT: can outperform Adam very slightly. Usually very similar performance.
+
+     .. warning::
+         Experimental.
+
+     """
+     def __init__(
+         self,
+         beta1: float = 0.9,
+         beta2: float = 0.99,
+         beta3: float = 0.99,
+         eps: float = 1e-8,
+         debiased:bool=True,
+         alpha: float = 1.,
+     ):
+         defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         beta1,beta2,beta3,eps,alpha=unpack_dicts(settings, 'beta1','beta2','beta3','eps','alpha', cls=NumberList)
+         exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
+
+         return cubic_adam_(
+             tensors=TensorList(tensors),
+             exp_avg_=exp_avg,
+             exp_avg_sq_=exp_avg_sq,
+             exp_avg_cu_=exp_avg_cu,
+             alpha=alpha,
+             beta1=beta1,
+             beta2=beta2,
+             beta3=beta3,
+             eps=eps,
+             debiased=settings[0]['debiased'],
+             step=step,
+         )
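
The core of cubic_adam_ is a per-coordinate choice between the two critical points of (A/3)x^3 + (B/2)x^2 + Cx, with a fall-back to the usual Adam-style ratio -C/(B + eps) wherever the derivative Ax^2 + Bx + C has a negative discriminant. A small self-contained check of that selection logic on plain tensors (the helper name is mine):

import torch

def cubic_step(m1, m2, m3, eps=1e-8):
    # A = signed cube root of the third moment, B = sqrt of the second, C = the first,
    # matching the roles they play in cubic_adam_ above.
    A = m3.sign() * m3.abs().pow(1 / 3)
    B = m2.sqrt()
    C = m1
    disc = B.pow(2) - 4 * A * C              # discriminant of the derivative A x^2 + B x + C
    root = disc.clamp(min=0).sqrt()
    x0 = (-B + root) / (2 * A + eps)
    x1 = (-B - root) / (2 * A + eps)
    f0 = (A / 3) * x0**3 + (B / 2) * x0**2 + C * x0
    f1 = (A / 3) * x1**3 + (B / 2) * x1**2 + C * x1
    x_star = torch.where(f0 < f1, x0, x1)    # pick the lower of the two critical points
    return torch.where(disc < 0, -C / (B + eps), x_star)  # no real root -> quadratic (Adam-like) step

g = torch.tensor([0.5, -1.0, 2.0])
print(cubic_step(g, g**2, g**3))  # with raw powers as "moments" the discriminant is negative, so this is ~ -sign(g)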
torchzero/modules/{projections → experimental}/dct.py
@@ -1,13 +1,13 @@
  from typing import Literal
  import torch
  import torch_dct
- from .projection import Projection
+ from ..projections import ProjectionBase
  from ...core import Chainable

  def reverse_dims(t:torch.Tensor):
      return t.permute(*reversed(range(t.ndim)))

- class DCTProjection(Projection):
+ class DCTProjection(ProjectionBase):
      # norm description copied from pytorch docstring
      """Project update into Discrete Cosine Transform space, requires `torch_dct` library.

@@ -34,8 +34,8 @@ class DCTProjection(Projection):
          super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

      @torch.no_grad
-     def project(self, tensors, var, current):
-         settings = self.settings[var.params[0]]
+     def project(self, tensors, params, grads, loss, states, settings, current):
+         settings = settings[0]
          dims = settings['dims']
          norm = settings['norm']

@@ -54,18 +54,18 @@
          return projected

      @torch.no_grad
-     def unproject(self, tensors, var, current):
-         settings = self.settings[var.params[0]]
+     def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
+         settings = projected_settings[0]
          dims = settings['dims']
          norm = settings['norm']

          unprojected = []
-         for u in tensors:
-             dim = min(u.ndim, dims)
+         for t in projected_tensors:
+             dim = min(t.ndim, dims)

-             if dim == 1: idct = torch_dct.idct(u, norm = norm)
-             elif dim == 2: idct = torch_dct.idct_2d(u, norm=norm)
-             elif dim == 3: idct = torch_dct.idct_3d(u, norm=norm)
+             if dim == 1: idct = torch_dct.idct(t, norm = norm)
+             elif dim == 2: idct = torch_dct.idct_2d(t, norm=norm)
+             elif dim == 3: idct = torch_dct.idct_3d(t, norm=norm)
              else: raise ValueError(f"Unsupported number of dimensions {dim}")

              unprojected.append(reverse_dims(idct))
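
Besides the rename from Projection to ProjectionBase, this hunk shows the new projection signature: project and unproject now receive params, grads, loss and per-parameter states/settings lists instead of pulling them out of var. A rough sketch of what a custom projection might look like under that signature; the class, its import path and the constructor flags are inferred only from the lines visible in this diff, so treat it as an assumption rather than documented API:

import torch
from torchzero.modules.projections import ProjectionBase  # path as used by this diff

class NegateProjection(ProjectionBase):
    # Toy projection that flips the sign of every tensor, purely to show the hooks.
    def __init__(self, modules, project_update=True, project_params=False, project_grad=False):
        super().__init__(modules, project_update=project_update,
                         project_params=project_params, project_grad=project_grad,
                         defaults=dict())

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        return [t.neg() for t in tensors]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
        return [t.neg() for t in projected_tensors]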
torchzero/modules/experimental/eigendescent.py
@@ -23,7 +23,10 @@ def _cosine_similarity(x, y):

  class EigenDescent(Module):
      """
-     Uses eigenvectors corresponding to certain eigenvalues. Please note that this is experimental and isn't guaranteed to work.
+     Uses eigenvectors corresponding to certain eigenvalues. For now they are just extracted from hessian.
+
+     .. warning::
+         Experimental.

      Args:
          mode (str, optional):
torchzero/modules/experimental/etf.py
@@ -4,13 +4,17 @@ import warnings
  import torch

  from ...core import Module
- from ...utils import vec_to_tensors, vec_to_tensors_
+ from ...utils import vec_to_tensors, vec_to_tensors_, as_tensorlist


  class ExponentialTrajectoryFit(Module):
-     """A method. Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(self, step_size=1e-3):
-         defaults = dict(step_size = step_size)
+     """A method.
+
+     .. warning::
+         Experimental.
+     """
+     def __init__(self, step_size=1e-2, adaptive:bool=True):
+         defaults = dict(step_size = step_size,adaptive=adaptive)
          super().__init__(defaults)

      @torch.no_grad
@@ -18,11 +22,17 @@ class ExponentialTrajectoryFit(Module):
          closure = var.closure
          assert closure is not None
          step_size = self.settings[var.params[0]]['step_size']
+         adaptive = self.settings[var.params[0]]['adaptive']
+

          # 1. perform 3 GD steps to obtain 4 points
          points = [torch.cat([p.view(-1) for p in var.params])]
          for i in range(3):
-             if i == 0: grad = var.get_grad()
+             if i == 0:
+                 grad = var.get_grad()
+                 if adaptive:
+                     step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
+
              else:
                  with torch.enable_grad(): closure()
                  grad = [cast(torch.Tensor, p.grad) for p in var.params]
@@ -67,9 +77,14 @@ class ExponentialTrajectoryFit(Module):


  class ExponentialTrajectoryFitV2(Module):
-     """Should be better than one above, except it isn't. Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(self, step_size=1e-3, num_steps: int= 4):
-         defaults = dict(step_size = step_size, num_steps=num_steps)
+     """Should be better than one above, except it isn't.
+
+     .. warning::
+         Experimental.
+
+     """
+     def __init__(self, step_size=1e-3, num_steps: int= 4, adaptive:bool=True):
+         defaults = dict(step_size = step_size, num_steps=num_steps, adaptive=adaptive)
          super().__init__(defaults)

      @torch.no_grad
@@ -78,9 +93,13 @@ class ExponentialTrajectoryFitV2(Module):
          assert closure is not None
          step_size = self.settings[var.params[0]]['step_size']
          num_steps = self.settings[var.params[0]]['num_steps']
+         adaptive = self.settings[var.params[0]]['adaptive']

          # 1. perform 3 GD steps to obtain 4 points (or more)
          grad = var.get_grad()
+         if adaptive:
+             step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
+
          points = [torch.cat([p.view(-1) for p in var.params])]
          point_grads = [torch.cat([g.view(-1) for g in grad])]

@@ -132,7 +151,11 @@ def _fit_exponential(y0, y1, y2):
      return A, B, r

  class PointwiseExponential(Module):
-     """A stupid method (for my youtube channel). Please note that this is experimental and isn't guaranteed to work."""
+     """A stupid method (for my youtube channel).
+
+     .. warning::
+         Experimental.
+     """
      def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
          defaults = dict(reg=reg, steps=steps, step_size=step_size)
          super().__init__(defaults)
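
The new adaptive flag in both trajectory-fit modules rescales step_size by the mean absolute gradient entry (clipped away from zero), so the probing GD steps have a roughly scale-free length. On a plain list of tensors the normalization reduces to the following sketch (the helper name is mine):

import torch

def normalized_step_size(step_size, grads, min_scale=1e-4):
    # Divide by the mean |g| over all parameters, as in the adaptive=True branch above.
    flat = torch.cat([g.reshape(-1) for g in grads])
    return step_size / flat.abs().mean().clamp(min=min_scale)

grads = [torch.tensor([[0.2, -0.4]]), torch.tensor([1.0])]
print(normalized_step_size(1e-2, grads))  # 1e-2 / 0.5333... ≈ 0.01875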
torchzero/modules/experimental/exp_adam.py (new file)
@@ -0,0 +1,113 @@
+ from operator import itemgetter
+ from functools import partial
+ import math
+ import torch
+
+ from ...core import Module, Target, Transform, apply_transform, Chainable
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..functional import (
+     debias, debiased_step_size,
+     ema_,
+     sqrt_ema_sq_,
+ )
+ from ..step_size.lr import lazy_lr
+ from ..momentum.experimental import sqrt_nag_ema_sq_
+ from ..momentum.momentum import nag_
+
+
+ def exp_adam_(
+     tensors: TensorList,
+     exp_avg_: TensorList,
+     exp_avg_exp_: TensorList,
+     alpha: float | NumberList,
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     eps: float | NumberList,
+     step: int,
+     pow: float = 2,
+     debiased: bool = True,
+     max_exp_avg_exp_: TensorList | None = None,
+
+     # inner args
+     inner: Module | None = None,
+     params: list[torch.Tensor] | None = None,
+     grads: list[torch.Tensor] | None = None,
+ ):
+     """Returns new tensors."""
+     tensors_exp = tensors.abs().clip_(max=math.log(torch.finfo(tensors[0].dtype).max) / 2).exp_()
+     exp_avg_exp_.lerp_(tensors_exp, 1-beta2)
+
+     if max_exp_avg_exp_ is not None:
+         max_exp_avg_exp_.maximum_(exp_avg_exp_)
+         exp_avg_exp_ = max_exp_avg_exp_
+
+     if inner is not None:
+         assert params is not None
+         tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+     return (exp_avg_.lazy_mul(alpha) / exp_avg_exp_.log().add_(eps))
+
+ class ExpAdam(Transform):
+     """Adam but uses abs exp and log instead of square and sqrt.
+     The gradient will be clipped to half the maximum value representable by its dtype (around 50 for float32)
+
+     Args:
+         beta1 (float, optional): momentum. Defaults to 0.9.
+         beta2 (float, optional): second momentum. Defaults to 0.999.
+         eps (float, optional): epsilon. Defaults to 1e-8.
+         alpha (float, optional): learning rate. Defaults to 1.
+         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+         pow (float, optional): power used in second momentum power and root. Defaults to 2.
+         debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+     """
+     def __init__(
+         self,
+         beta1: float = 0.9,
+         beta2: float = 0.999,
+         eps: float = 1e-8,
+         amsgrad: bool = False,
+         alpha: float = 1.,
+         pow: float = 2,
+         debiased: bool = True,
+         inner: Chainable | None = None
+     ):
+         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
+         super().__init__(defaults, uses_grad=False)
+
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+         amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
+
+         if amsgrad:
+             exp_avg, exp_avg_exp, max_exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', 'max_exp_avg_exp', cls=TensorList)
+         else:
+             exp_avg, exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', cls=TensorList)
+             max_exp_avg_exp = None
+
+
+         return exp_adam_(
+             tensors=TensorList(tensors),
+             exp_avg_=exp_avg,
+             exp_avg_exp_=exp_avg_exp,
+             alpha=alpha,
+             beta1=beta1,
+             beta2=beta2,
+             eps=eps,
+             step=step,
+             pow=pow,
+             debiased=debiased,
+             max_exp_avg_exp_=max_exp_avg_exp,
+
+             # inner args
+             inner=self.children.get("inner", None),
+             params=params,
+             grads=grads,
+
+         )
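
For reference, the update exp_adam_ computes can be written out on a single flat tensor: the second moment is an EMA of exp(|g|), with |g| clipped to half of log(dtype max) so the exponential cannot overflow, and the denominator is the log of that EMA plus eps rather than the square root of an EMA of g^2. The sketch below mirrors that formula with Adam-style bias correction folded into the step size (standing in for the package's debiased_step_size helper); it is an illustration, not the package API:

import math
import torch

def exp_adam_step(g, m, v_exp, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # EMA of exp(|g|); |g| is clipped so exp() cannot overflow in this dtype.
    clip = math.log(torch.finfo(g.dtype).max) / 2
    v_exp.lerp_(g.abs().clamp(max=clip).exp(), 1 - beta2)
    # first moment: plain EMA of the gradient
    m.lerp_(g, 1 - beta1)
    # Adam-style bias correction applied to the step size
    lr_t = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
    # the returned value is the update a surrounding optimizer would subtract from the parameters
    return lr_t * m / (v_exp.log() + eps)

g = torch.tensor([8.0, -10.0, 60.0])   # 60 exceeds the float32 clip (~44.4) and gets clamped
m, v = torch.zeros(3), torch.zeros(3)
print(exp_adam_step(g, m, v, step=1))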