torchzero 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tests/test_identical.py +1 -1
  2. torchzero/__init__.py +3 -1
  3. torchzero/_minimize/__init__.py +0 -0
  4. torchzero/_minimize/methods.py +95 -0
  5. torchzero/_minimize/minimize.py +518 -0
  6. torchzero/core/__init__.py +5 -5
  7. torchzero/core/chain.py +2 -1
  8. torchzero/core/functional.py +2 -1
  9. torchzero/core/module.py +75 -4
  10. torchzero/core/transform.py +6 -5
  11. torchzero/linalg/eigh.py +116 -68
  12. torchzero/linalg/linear_operator.py +1 -0
  13. torchzero/linalg/orthogonalize.py +60 -5
  14. torchzero/linalg/sketch.py +39 -0
  15. torchzero/modules/__init__.py +1 -0
  16. torchzero/modules/adaptive/adagrad.py +2 -0
  17. torchzero/modules/adaptive/adam.py +5 -1
  18. torchzero/modules/adaptive/adan.py +3 -0
  19. torchzero/modules/adaptive/ggt.py +20 -18
  20. torchzero/modules/adaptive/lion.py +3 -1
  21. torchzero/modules/adaptive/mars.py +6 -5
  22. torchzero/modules/adaptive/msam.py +3 -0
  23. torchzero/modules/adaptive/rmsprop.py +2 -0
  24. torchzero/modules/adaptive/rprop.py +9 -7
  25. torchzero/modules/adaptive/shampoo.py +9 -1
  26. torchzero/modules/adaptive/soap.py +32 -29
  27. torchzero/modules/basis/__init__.py +2 -0
  28. torchzero/modules/basis/ggt_basis.py +199 -0
  29. torchzero/modules/basis/soap_basis.py +254 -0
  30. torchzero/modules/clipping/ema_clipping.py +32 -27
  31. torchzero/modules/clipping/growth_clipping.py +1 -0
  32. torchzero/modules/experimental/__init__.py +1 -6
  33. torchzero/modules/experimental/coordinate_momentum.py +2 -0
  34. torchzero/modules/experimental/cubic_adam.py +4 -0
  35. torchzero/modules/grad_approximation/__init__.py +3 -2
  36. torchzero/modules/least_squares/gn.py +6 -0
  37. torchzero/modules/misc/gradient_accumulation.py +1 -0
  38. torchzero/modules/misc/misc.py +6 -0
  39. torchzero/modules/momentum/averaging.py +6 -0
  40. torchzero/modules/momentum/momentum.py +13 -9
  41. torchzero/modules/ops/__init__.py +0 -1
  42. torchzero/modules/ops/accumulate.py +4 -0
  43. torchzero/modules/ops/higher_level.py +6 -1
  44. torchzero/modules/second_order/inm.py +4 -0
  45. torchzero/modules/second_order/newton.py +11 -3
  46. torchzero/modules/second_order/newton_cg.py +7 -3
  47. torchzero/modules/second_order/nystrom.py +14 -19
  48. torchzero/modules/second_order/rsn.py +37 -6
  49. torchzero/modules/trust_region/trust_region.py +2 -1
  50. torchzero/utils/benchmarks/logistic.py +33 -18
  51. torchzero/utils/optuna_tools.py +1 -1
  52. torchzero/utils/params.py +13 -1
  53. torchzero/utils/tensorlist.py +2 -2
  54. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA +1 -1
  55. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD +58 -55
  56. torchzero/modules/experimental/adanystrom.py +0 -258
  57. torchzero/modules/experimental/common_directions_whiten.py +0 -142
  58. torchzero/modules/experimental/eigen_sr1.py +0 -182
  59. torchzero/modules/experimental/eigengrad.py +0 -207
  60. /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
  61. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/WHEEL +0 -0
  62. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/top_level.txt +0 -0
torchzero/modules/basis/soap_basis.py (new file)
@@ -0,0 +1,254 @@
+ from operator import itemgetter
+ import warnings
+
+ import torch
+
+ from ...core import TensorTransform, Chainable, Module
+ from ..adaptive import Adam
+ from ...utils import unpack_dicts, unpack_states, TensorList, NumberList, set_storage_
+ from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
+ from ...linalg import torch_linalg
+ from ..adaptive.soap import get_orthogonal_matrix, project, project_back, update_soap_covariances_
+
+ # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+ @torch.no_grad
+ def get_orthogonal_matrix_QR(grad_sqs: list[torch.Tensor], GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+     """
+     Computes the eigenbases of the preconditioner using one round of power iteration
+     followed by torch.linalg.qr decomposition.
+     """
+     final = []
+
+     for ind, (M, O) in enumerate(zip(GG, Q_list)):
+
+         # skip 1d or large dims
+         if M is None:
+             final.append(None)
+             continue
+
+         assert O is not None
+
+         est_eig = torch.diagonal(O.T @ M @ O)
+         sort_idx = torch.argsort(est_eig, descending=True)
+         grad_sqs = [s.index_select(ind, sort_idx) for s in grad_sqs]
+
+         power_iter = M @ O[:, sort_idx]
+         Q, _ = torch_linalg.qr(power_iter.to(torch.float32), retry_float64=True)
+         Q = Q.to(power_iter.dtype)
+
+         final.append(Q)
+
+     return final, grad_sqs
+
+ class SOAPBasis(TensorTransform):
+     """
+     Run another optimizer in Shampoo eigenbases.
+
+     Note:
+         the buffers of the ``basis_opt`` are re-projected whenever basis changes. The reprojection logic is not implemented on all modules. Some supported modules are:
+
+         ``Adagrad``, ``Adam``, ``Adan``, ``Lion``, ``MARSCorrection``, ``MSAMMomentum``, ``RMSprop``, ``EMA``, ``HeavyBall``, ``NAG``, ``ClipNormByEMA``, ``ClipValueByEMA``, ``NormalizeByEMA``, ``ClipValueGrowth``, ``CoordinateMomentum``, ``CubicAdam``.
+
+         Additionally most modules with no internal buffers are supported, e.g. ``Cautious``, ``Sign``, ``ClipNorm``, ``Orthogonalize``, etc. However modules that use weight values, such as ``WeighDecay`` can't be supported, as weights can't be projected.
+
+         Also, if you say use ``EMA`` on output of ``Pow(2)``, the exponential average will be reprojected as gradient and not as squared gradients. Use modules like ``EMASquared``, ``SqrtEMASquared`` to get correct reprojections.
+
+     Args:
+         basis_opt (Chainable): module or modules to run in Shampoo eigenbases.
+         shampoo_beta (float | None, optional):
+             beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
+         precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
+         merge_small (bool, optional): Whether to merge small dims. Defaults to True.
+         max_dim (int, optional): Won't precondition dims larger than this. Defaults to 10_000.
+         precondition_1d (bool, optional):
+             Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
+         inner (Chainable | None, optional):
+             output of this module is projected and ``basis_opt`` will run on it, but preconditioners are updated
+             from original gradients.
+
+     Examples:
+         SOAP with MARS and AMSGrad:
+         ```python
+         opt = tz.Optimizer(
+             model.parameters(),
+             tz.m.SOAPBasis([tz.m.MARSCorrection(0.95), tz.m.Adam(0.95, 0.95, amsgrad=True)]),
+             tz.m.LR(1e-3)
+         )
+         ```
+
+         LaProp in Shampoo eigenbases (SOLP):
+         ```python
+
+         # we define LaProp through other modules, moved it out for brevity
+         laprop = (
+             tz.m.RMSprop(0.95),
+             tz.m.Debias(beta1=None, beta2=0.95),
+             tz.m.EMA(0.95),
+             tz.m.Debias(beta1=0.95, beta2=None),
+         )
+
+         opt = tz.Optimizer(
+             model.parameters(),
+             tz.m.SOAPBasis(laprop),
+             tz.m.LR(1e-3)
+         )
+         ```
+
+         Lion in Shampoo eigenbases (works kinda well):
+         ```python
+
+         opt = tz.Optimizer(
+             model.parameters(),
+             tz.m.SOAPBasis(tz.m.Lion()),
+             tz.m.LR(1e-3)
+         )
+         ```
+     """
+     def __init__(
+         self,
+         basis_opt: Chainable,
+         shampoo_beta: float | None = 0.95,
+         precond_freq: int = 10,
+         merge_small: bool = True,
+         max_dim: int = 4096,
+         precondition_1d: bool = True,
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         del defaults['self'], defaults["inner"], defaults["basis_opt"]
+
+         super().__init__(defaults)
+         self.set_child("inner", inner)
+         self.set_child("basis_opt", basis_opt)
+
+     @torch.no_grad
+     def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+         if setting["merge_small"]:
+             tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+         state["exp_avg_proj"] = torch.zeros_like(tensor)
+         state["exp_avg_sq_proj"] = torch.zeros_like(tensor)
+
+         if tensor.ndim <= 1 and not setting["precondition_1d"]:
+             state['GG'] = []
+
+         else:
+             max_dim = setting["max_dim"]
+             state['GG'] = [
+                 torch.zeros(s, s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+             ]
+
+         # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+         if len([i is not None for i in state['GG']]) == 0:
+             state['GG'] = None
+
+         # first covariance accumulation
+         if state['GG'] is not None:
+             update_soap_covariances_(tensor, GGs_=state['GG'], beta=setting["shampoo_beta"])
+
+             # get projection matrix with first gradients with eigh
+             try: state['Q'] = get_orthogonal_matrix(state['GG'])
+             except torch.linalg.LinAlgError as e:
+                 warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
+                 state["GG"] = None
+
+         state['step'] = 0
+
+
+     # no update to avoid running merge_dims twice
+
+     @torch.no_grad
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+         # note
+         # do not modify tensors in-place
+         # because they are used to update preconditioner at the end
+
+         steps = [s["step"] for s in states]
+         if any(s == 0 for s in steps):
+             # skip 1st update so to avoid using current gradient in the projection
+             # I scale it instead to avoid issues with further modules
+             for s in states: s["step"] += 1
+             return TensorList(tensors).clamp(-0.1, 0.1)
+             # return TensorList(tensors).zero_()
+
+         merged_updates = [] # for when exp_avg is maintained unprojected
+         merged_grads = [] # this doesn't go into preconditioner
+         projected = []
+
+         # -------------------------------- inner step -------------------------------- #
+         updates = tensors
+         has_inner = "inner" in self.children
+         if has_inner:
+             updates = self.inner_step_tensors("inner", updates, clone=True,
+                                               params=params, grads=grads, loss=loss)
+
+         # ---------------------------------- project --------------------------------- #
+         for grad, update, state, setting in zip(tensors, updates, states, settings):
+             if setting["merge_small"]:
+                 update, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(update, setting["max_dim"])
+                 if has_inner: # grad is a different tensor, merge it too
+                     grad, _, _ = _merge_small_dims(grad, setting["max_dim"])
+                 else: # in this case update is still just grad
+                     grad = update
+
+             merged_updates.append(update)
+             merged_grads.append(grad)
+
+             if state['GG'] is not None:
+                 update = project(update, state['Q'])
+
+             projected.append(update)
+
+
+         # ------------------------ run opt in projected space ----------------------- #
+         dirs_proj = self.inner_step_tensors("basis_opt", tensors=projected, clone=True, grads=projected)
+
+         # ------------------------------- project back ------------------------------- #
+         dirs: list[torch.Tensor] = []
+         for dir, state, setting in zip(dirs_proj, states, settings):
+             if state['GG'] is not None:
+                 dir = project_back(dir, state['Q'])
+
+             if setting["merge_small"]:
+                 dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])
+
+             dirs.append(dir)
+
+         # -------------------------- update preconditioners -------------------------- #
+         # Update is done after the gradient step to avoid using current gradients in the projection.
+
+         grad_buffs = self.get_child_projected_buffers("basis_opt", "grad")
+         grad_sq_buffs = self.get_child_projected_buffers("basis_opt", ["grad_sq", "grad_cu"])
+
+         for i, (grad, state, setting) in enumerate(zip(merged_grads, states, settings)):
+             if state['GG'] is not None:
+
+                 # lerp covariances
+                 update_soap_covariances_(grad, state['GG'], beta=setting["shampoo_beta"])
+
+                 # (state['step'] - 1) since we start updating on 2nd step
+                 if (state['step'] - 1) % setting['precond_freq'] == 0:
+                     g_buffs = [b[i] for b in grad_buffs]
+                     g_sq_buffs = [b[i] for b in grad_sq_buffs]
+
+                     # unproject grad buffers before updating
+                     g_buffs_unproj = [project_back(buff, state["Q"]) for buff in g_buffs]
+
+                     # update projection matrix and exp_avg_sq_proj
+                     try:
+                         state['Q'], g_sq_buffs_new = get_orthogonal_matrix_QR(
+                             g_sq_buffs, state['GG'], state['Q'])
+
+                         for b_old, b_new in zip(g_sq_buffs, g_sq_buffs_new):
+                             set_storage_(b_old, b_new)
+
+                         # re-project grad buffers
+                         for b_proj, b_unproj in zip(g_buffs, g_buffs_unproj):
+                             set_storage_(b_proj, project(b_unproj, state["Q"]))
+
+                     except torch.linalg.LinAlgError:
+                         pass
+
+             state["step"] += 1
+
+         return dirs
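The round trip at the heart of ``SOAPBasis`` is: project the update into the Shampoo eigenbases, run ``basis_opt`` there, then project its output back. A standalone sketch of that idea for a single 2-D tensor (plain PyTorch; the real ``project``/``project_back`` helpers imported above work on lists of per-dimension factors with ``None`` for skipped dims, so this is a simplification, not the module's API):

```python
import torch

def project(g: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
    # rotate the update into the eigenbases: G' = Q_left^T @ G @ Q_right
    return Q[0].T @ g @ Q[1]

def project_back(d: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
    # rotate the inner optimizer's output back to the original basis
    return Q[0] @ d @ Q[1].T

g = torch.randn(8, 4)

# toy eigenbases from the left/right gradient covariances (what GG accumulates)
GG = [g @ g.T, g.T @ g]
Q = [torch.linalg.eigh(M).eigenvectors for M in GG]

g_proj = project(g, Q)       # the inner optimizer (e.g. Adam) would consume this
d = project_back(g_proj, Q)  # and its output is rotated back before being returned
torch.testing.assert_close(d, g, atol=1e-5, rtol=1e-5)  # identity here, since no inner step was applied
```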
torchzero/modules/clipping/ema_clipping.py
@@ -27,60 +27,67 @@ class ClipNormByEMA(TensorTransform):
          self,
          beta=0.99,
          ord: Metrics = 2,
-         eps=1e-6,
          tensorwise:bool=True,
          max_ema_growth: float | None = 1.5,
-         ema_init: Literal['zeros', 'update'] = 'zeros',
+         init: float = 0.0,
+         min_norm: float = 1e-6,
+
          inner: Chainable | None = None,
      ):
-         defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
+         defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, init=init, min_norm=min_norm, max_ema_growth=max_ema_growth)
          super().__init__(defaults, inner=inner)
+         self.add_projected_keys("grad", "exp_avg")

      @torch.no_grad
      def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
          tensors = TensorList(tensors)
-         ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])
+         eps = torch.finfo(tensors[0].dtype).tiny * 2
+         ord, tensorwise, init, max_ema_growth = itemgetter('ord', 'tensorwise', 'init', 'max_ema_growth')(settings[0])

-         beta, eps = unpack_dicts(settings, 'beta', 'eps', cls=NumberList)
+         beta, min_norm = unpack_dicts(settings, 'beta', 'min_norm', cls=NumberList)

-         ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
+         exp_avg = unpack_states(states, tensors, 'exp_avg', init = lambda x: torch.full_like(x, init), cls=TensorList)

-         ema.lerp_(tensors, 1-beta)
+         exp_avg.lerp_(tensors, 1-beta)

+         # ----------------------------- tensorwise update ---------------------------- #
          if tensorwise:
-             ema_norm = ema.metric(ord)
+             tensors_norm = tensors.norm(ord)
+             ema_norm = exp_avg.metric(ord)

              # clip ema norm growth
              if max_ema_growth is not None:
                  prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
-                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
+                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=min_norm)
+
                  ema_denom = (ema_norm / allowed_norm).clip(min=1)
-                 ema.div_(ema_denom)
+                 exp_avg.div_(ema_denom)
                  ema_norm.div_(ema_denom)
+
                  prev_ema_norm.set_(ema_norm)

-             tensors_norm = tensors.norm(ord)
-             denom = tensors_norm / ema_norm.clip(min=eps)
-             if self.NORMALIZE: denom.clip_(min=eps)
-             else: denom.clip_(min=1)

+         # ------------------------------- global update ------------------------------ #
          else:
-             ema_norm = ema.global_metric(ord)
+             tensors_norm = tensors.global_metric(ord)
+             ema_norm = exp_avg.global_metric(ord)

              # clip ema norm growth
              if max_ema_growth is not None:
                  prev_ema_norm = self.global_state.setdefault('prev_ema_norm', ema_norm)
-                 allowed_norm = prev_ema_norm * max_ema_growth
+                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=min_norm[0])
+
                  if ema_norm > allowed_norm:
-                     ema.div_(ema_norm / allowed_norm)
+                     exp_avg.div_(ema_norm / allowed_norm)
                      ema_norm = allowed_norm
+
                  prev_ema_norm.set_(ema_norm)

-             tensors_norm = tensors.global_metric(ord)
-             denom = tensors_norm / ema_norm.clip(min=eps[0])
-             if self.NORMALIZE: denom.clip_(min=eps[0])
-             else: denom.clip_(min=1)

+         # ------------------- compute denominator to clip/normalize ------------------ #
+         denom = tensors_norm / ema_norm.clip(min=eps)
+         if self.NORMALIZE: denom.clip_(min=eps)
+         else: denom.clip_(min=1)
          self.global_state['denom'] = denom

      @torch.no_grad
@@ -121,7 +128,7 @@ class ClipValueByEMA(TensorTransform):
      def __init__(
          self,
          beta=0.99,
-         init: Literal['zeros', 'update'] = 'zeros',
+         init: float = 0,

          inner: Chainable | None = None,
          exp_avg_tfm:Chainable | None=None,
@@ -130,12 +137,10 @@ class ClipValueByEMA(TensorTransform):
          super().__init__(defaults, inner=inner)

          self.set_child('exp_avg', exp_avg_tfm)
+         self.add_projected_keys("grad", "exp_avg")

      def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
-         if setting["init"] == "zeros":
-             state["exp_avg"] = torch.zeros_like(tensor)
-         else:
-             state["exp_avg"] = tensor.abs()
+         state["exp_avg"] = tensor.abs() * setting["init"]

      @torch.no_grad
      def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
@@ -153,4 +158,4 @@ class ClipValueByEMA(TensorTransform):
              self.inner_step_tensors("exp_avg", exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

          tensors.clip_(-exp_avg, exp_avg)
-         return tensors
+         return tensors
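In the ``ClipNormByEMA`` hunk above the user-facing ``eps`` setting is gone: ``eps`` is now derived from the tensor dtype, the EMA buffer is renamed to ``exp_avg`` and initialized to the constant ``init``, and the growth clip uses the new ``min_norm``. The shared clip/normalize denominator, reduced to a single plain tensor (illustrative sketch only; the function name is made up):

```python
import torch

def ema_clip_denom(update: torch.Tensor, exp_avg: torch.Tensor,
                   ord: float = 2.0, normalize: bool = False) -> torch.Tensor:
    # eps is no longer a setting; it is the smallest positive value of the dtype, doubled
    eps = torch.finfo(update.dtype).tiny * 2

    update_norm = torch.linalg.vector_norm(update, ord)
    ema_norm = torch.linalg.vector_norm(exp_avg, ord)

    denom = update_norm / ema_norm.clamp(min=eps)
    # clipping only ever shrinks the update (denom >= 1);
    # normalizing rescales it in both directions
    return denom.clamp(min=eps) if normalize else denom.clamp(min=1)

# the update is then divided by this denominator
update, exp_avg = torch.randn(10), torch.randn(10)
clipped = update / ema_clip_denom(update, exp_avg)
```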
torchzero/modules/clipping/growth_clipping.py
@@ -30,6 +30,7 @@ class ClipValueGrowth(TensorTransform):
      ):
          defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "prev")


      def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
torchzero/modules/experimental/__init__.py
@@ -1,13 +1,9 @@
- """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
- from .adanystrom import AdaNystrom
- from .common_directions_whiten import CommonDirectionsWhiten
+ """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested."""
  from .coordinate_momentum import CoordinateMomentum
  from .cubic_adam import CubicAdam, SubspaceCubicAdam
  from .curveball import CurveBall
- from .eigen_sr1 import EigenSR1

  # from dct import DCTProjection
- from .eigengrad import Eigengrad
  from .fft import FFTProjection
  from .gradmin import GradMin
  from .higher_order_newton import HigherOrderNewton
@@ -16,5 +12,4 @@ from .newton_solver import NewtonSolver
  from .newtonnewton import NewtonNewton
  from .reduce_outward_lr import ReduceOutwardLR
  from .scipy_newton_cg import ScipyNewtonCG
- from .spsa1 import SPSA1
  from .structural_projections import BlockPartition, TensorizeProjection
torchzero/modules/experimental/coordinate_momentum.py
@@ -29,6 +29,8 @@ class CoordinateMomentum(TensorTransform):
          defaults = dict(p=p)
          super().__init__(defaults)

+         self.add_projected_keys("grad", "velocity")
+
      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          p = NumberList(s['p'] for s in settings)
torchzero/modules/experimental/cubic_adam.py
@@ -88,6 +88,10 @@ class CubicAdam(TensorTransform):
          defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha,mode=mode)
          super().__init__(defaults)

+         self.add_projected_keys("grad", "exp_avg")
+         self.add_projected_keys("grad_sq", "exp_avg_sq")
+         self.add_projected_keys("grad_cu", "exp_avg_cu")
+
      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
torchzero/modules/grad_approximation/__init__.py
@@ -1,4 +1,5 @@
- from .grad_approximator import GradApproximator, GradTarget
  from .fdm import FDM
- from .rfdm import RandomizedFDM, MeZO, SPSA, RDSA, GaussianSmoothing
  from .forward_gradient import ForwardGradient
+ from .grad_approximator import GradApproximator, GradTarget
+ from .rfdm import RDSA, SPSA, GaussianSmoothing, MeZO, RandomizedFDM
+ from .spsa1 import SPSA1
torchzero/modules/least_squares/gn.py
@@ -1,3 +1,5 @@
+ import warnings
+
  import torch

  from ...core import Chainable, Transform
@@ -129,6 +131,10 @@ class GaussNewton(Transform):
          r = objective.get_loss(backward=False) # n_residuals
          assert isinstance(r, torch.Tensor)

+         if r.numel() == 1:
+             r = r.view(1,1)
+             warnings.warn("Gauss-newton got a single residual. Make sure objective function returns a vector of residuals.")
+
          # set sum of squares scalar loss and it's gradient to objective
          objective.loss = r.pow(2).sum()

torchzero/modules/misc/gradient_accumulation.py
@@ -35,6 +35,7 @@ class GradientAccumulation(Module):
      def __init__(self, n: int, mean=True, stop=True):
          defaults = dict(n=n, mean=mean, stop=stop)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "accumulator")


      @torch.no_grad
torchzero/modules/misc/misc.py
@@ -25,6 +25,7 @@ class Previous(TensorTransform):
          defaults = dict(n=n)
          super().__init__(defaults=defaults)

+         self.add_projected_keys("grad", "history")

      @torch.no_grad
      def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
@@ -42,6 +43,7 @@ class LastDifference(TensorTransform):
      """Outputs difference between past two updates."""
      def __init__(self,):
          super().__init__()
+         self.add_projected_keys("grad", "prev_tensors")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -54,6 +56,7 @@ class LastGradDifference(Module):
      """Outputs difference between past two gradients."""
      def __init__(self):
          super().__init__()
+         self.add_projected_keys("grad", "prev_grad")

      @torch.no_grad
      def apply(self, objective):
@@ -84,6 +87,7 @@ class LastProduct(TensorTransform):
      """Outputs difference between past two updates."""
      def __init__(self):
          super().__init__()
+         self.add_projected_keys("grad", "prev")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -97,6 +101,7 @@ class LastRatio(TensorTransform):
      def __init__(self, numerator: Literal['cur', 'prev'] = 'cur'):
          defaults = dict(numerator=numerator)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "prev")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -112,6 +117,7 @@ class LastAbsoluteRatio(TensorTransform):
      def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8):
          defaults = dict(numerator=numerator, eps=eps)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "prev")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
torchzero/modules/momentum/averaging.py
@@ -20,6 +20,8 @@ class Averaging(TensorTransform):
          defaults = dict(history_size=history_size)
          super().__init__(defaults=defaults)

+         self.add_projected_keys("grad", "history", "average")
+
      @torch.no_grad
      def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
          history_size = setting['history_size']
@@ -45,6 +47,8 @@ class WeightedAveraging(TensorTransform):
          defaults = dict(weights = tolist(weights))
          super().__init__(defaults=defaults)

+         self.add_projected_keys("grad", "history")
+
      @torch.no_grad
      def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
          weights = setting['weights']
@@ -79,6 +83,8 @@ class MedianAveraging(TensorTransform):
          defaults = dict(history_size = history_size)
          super().__init__(defaults=defaults)

+         self.add_projected_keys("grad", "history")
+
      @torch.no_grad
      def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
          history_size = setting['history_size']
torchzero/modules/momentum/momentum.py
@@ -6,7 +6,7 @@ import torch

  from ...core import TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..opt_utils import debias, ema_
+ from ..opt_utils import debias as _debias, ema_


  class EMA(TensorTransform):
@@ -15,20 +15,22 @@ class EMA(TensorTransform):
      Args:
          momentum (float, optional): momentum (beta). Defaults to 0.9.
          dampening (float, optional): momentum dampening. Defaults to 0.
-         debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+         debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
          lerp (bool, optional): whether to use linear interpolation. Defaults to True.
          ema_init (str, optional): initial values for the EMA, "zeros" or "update".
          target (Target, optional): target to apply EMA to. Defaults to 'update'.
      """
-     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
-         defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
+     def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
+         defaults = dict(momentum=momentum,dampening=dampening,debias=debias,lerp=lerp,ema_init=ema_init)
          super().__init__(defaults, uses_grad=False)

+         self.add_projected_keys("grad", "exp_avg")
+
      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-         debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
+         debias, lerp, ema_init = itemgetter('debias','lerp','ema_init')(settings[0])

          exp_avg = unpack_states(states, tensors, 'exp_avg',
                                  init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
@@ -36,7 +38,7 @@ class EMA(TensorTransform):

          exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)

-         if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+         if debias: return _debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
          else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned


@@ -47,14 +49,14 @@ class HeavyBall(EMA):
      Args:
          momentum (float, optional): momentum (beta). Defaults to 0.9.
          dampening (float, optional): momentum dampening. Defaults to 0.
-         debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+         debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
          lerp (bool, optional):
              whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
          ema_init (str, optional): initial values for the EMA, "zeros" or "update".
          target (Target, optional): target to apply EMA to. Defaults to 'update'.
      """
-     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
-         super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init)
+     def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
+         super().__init__(momentum=momentum, dampening=dampening, debias=debias, lerp=lerp, ema_init=ema_init)

  def nag_(
      tensors_: TensorList,
@@ -88,6 +90,8 @@ class NAG(TensorTransform):
          defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
          super().__init__(defaults, uses_grad=False)

+         self.add_projected_keys("grad", "velocity")
+
      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
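For reference, the ``debias`` flag renamed above applies the usual Adam-style bias correction to the exponential moving average; conceptually (the real ``_debias`` helper takes more arguments):

```python
import torch

def debias(exp_avg: torch.Tensor, beta: float, step: int) -> torch.Tensor:
    # the EMA starts at zero, so early values are scaled up by 1 / (1 - beta**step)
    return exp_avg / (1 - beta ** step)
```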
torchzero/modules/ops/__init__.py
@@ -13,7 +13,6 @@ from .binary import (
      CopySign,
      Div,
      GraftInputToOutput,
-     GraftInputToOutput,
      GramSchimdt,
      Maximum,
      Minimum,
torchzero/modules/ops/accumulate.py
@@ -13,6 +13,7 @@ class AccumulateSum(TensorTransform):
      def __init__(self, decay: float = 0):
          defaults = dict(decay=decay)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "sum")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -30,6 +31,7 @@ class AccumulateMean(TensorTransform):
      def __init__(self, decay: float = 0):
          defaults = dict(decay=decay)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "mean")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -65,6 +67,7 @@ class AccumulateMaximum(TensorTransform):
      def __init__(self, decay: float = 0):
          defaults = dict(decay=decay)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "maximum")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -82,6 +85,7 @@ class AccumulateMinimum(TensorTransform):
      def __init__(self, decay: float = 0):
          defaults = dict(decay=decay)
          super().__init__(defaults)
+         self.add_projected_keys("grad", "minimum")

      @torch.no_grad
      def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
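Most of the one-line additions in this diff are ``add_projected_keys`` calls, which register which state buffers are gradient-like so that the new basis modules (e.g. ``SOAPBasis``) can re-project them whenever the eigenbases change. A minimal sketch of a custom transform opting in, assuming the same ``TensorTransform`` API used by the modules above (untested):

```python
import torch
from torchzero.core import TensorTransform

class MyMomentum(TensorTransform):
    """Toy EMA transform whose buffer participates in basis re-projection."""

    def __init__(self, beta: float = 0.9):
        super().__init__(dict(beta=beta))
        # "velocity" is accumulated from (projected) gradients, so it is
        # registered under the "grad" key, like the momentum modules above
        self.add_projected_keys("grad", "velocity")

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        if "velocity" not in state:
            state["velocity"] = torch.zeros_like(tensor)
        velocity = state["velocity"]
        velocity.lerp_(tensor, 1 - setting["beta"])
        return velocity.clone()
```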