torchzero 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +3 -1
- torchzero/_minimize/__init__.py +0 -0
- torchzero/_minimize/methods.py +95 -0
- torchzero/_minimize/minimize.py +518 -0
- torchzero/core/__init__.py +5 -5
- torchzero/core/chain.py +2 -1
- torchzero/core/functional.py +2 -1
- torchzero/core/module.py +75 -4
- torchzero/core/transform.py +6 -5
- torchzero/linalg/eigh.py +116 -68
- torchzero/linalg/linear_operator.py +1 -0
- torchzero/linalg/orthogonalize.py +60 -5
- torchzero/linalg/sketch.py +39 -0
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/adaptive/adagrad.py +2 -0
- torchzero/modules/adaptive/adam.py +5 -1
- torchzero/modules/adaptive/adan.py +3 -0
- torchzero/modules/adaptive/ggt.py +20 -18
- torchzero/modules/adaptive/lion.py +3 -1
- torchzero/modules/adaptive/mars.py +6 -5
- torchzero/modules/adaptive/msam.py +3 -0
- torchzero/modules/adaptive/rmsprop.py +2 -0
- torchzero/modules/adaptive/rprop.py +9 -7
- torchzero/modules/adaptive/shampoo.py +9 -1
- torchzero/modules/adaptive/soap.py +32 -29
- torchzero/modules/basis/__init__.py +2 -0
- torchzero/modules/basis/ggt_basis.py +199 -0
- torchzero/modules/basis/soap_basis.py +254 -0
- torchzero/modules/clipping/ema_clipping.py +32 -27
- torchzero/modules/clipping/growth_clipping.py +1 -0
- torchzero/modules/experimental/__init__.py +1 -6
- torchzero/modules/experimental/coordinate_momentum.py +2 -0
- torchzero/modules/experimental/cubic_adam.py +4 -0
- torchzero/modules/grad_approximation/__init__.py +3 -2
- torchzero/modules/least_squares/gn.py +6 -0
- torchzero/modules/misc/gradient_accumulation.py +1 -0
- torchzero/modules/misc/misc.py +6 -0
- torchzero/modules/momentum/averaging.py +6 -0
- torchzero/modules/momentum/momentum.py +4 -0
- torchzero/modules/ops/__init__.py +0 -1
- torchzero/modules/ops/accumulate.py +4 -0
- torchzero/modules/ops/higher_level.py +6 -1
- torchzero/modules/second_order/inm.py +4 -0
- torchzero/modules/second_order/newton.py +11 -3
- torchzero/modules/second_order/newton_cg.py +7 -3
- torchzero/modules/second_order/nystrom.py +14 -19
- torchzero/modules/second_order/rsn.py +37 -6
- torchzero/modules/trust_region/trust_region.py +2 -1
- torchzero/utils/benchmarks/logistic.py +33 -18
- torchzero/utils/params.py +13 -1
- torchzero/utils/tensorlist.py +2 -2
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/METADATA +1 -1
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/RECORD +56 -53
- torchzero/modules/experimental/adanystrom.py +0 -258
- torchzero/modules/experimental/common_directions_whiten.py +0 -142
- torchzero/modules/experimental/eigen_sr1.py +0 -182
- torchzero/modules/experimental/eigengrad.py +0 -207
- /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/WHEEL +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/top_level.txt +0 -0
torchzero/modules/basis/soap_basis.py
ADDED
@@ -0,0 +1,254 @@
+from operator import itemgetter
+import warnings
+
+import torch
+
+from ...core import TensorTransform, Chainable, Module
+from ..adaptive import Adam
+from ...utils import unpack_dicts, unpack_states, TensorList, NumberList, set_storage_
+from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
+from ...linalg import torch_linalg
+from ..adaptive.soap import get_orthogonal_matrix, project, project_back, update_soap_covariances_
+
+# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+@torch.no_grad
+def get_orthogonal_matrix_QR(grad_sqs: list[torch.Tensor], GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+    """
+    Computes the eigenbases of the preconditioner using one round of power iteration
+    followed by torch.linalg.qr decomposition.
+    """
+    final = []
+
+    for ind, (M, O) in enumerate(zip(GG, Q_list)):
+
+        # skip 1d or large dims
+        if M is None:
+            final.append(None)
+            continue
+
+        assert O is not None
+
+        est_eig = torch.diagonal(O.T @ M @ O)
+        sort_idx = torch.argsort(est_eig, descending=True)
+        grad_sqs = [s.index_select(ind, sort_idx) for s in grad_sqs]
+
+        power_iter = M @ O[:, sort_idx]
+        Q, _ = torch_linalg.qr(power_iter.to(torch.float32), retry_float64=True)
+        Q = Q.to(power_iter.dtype)
+
+        final.append(Q)
+
+    return final, grad_sqs
+
+class SOAPBasis(TensorTransform):
+    """
+    Run another optimizer in Shampoo eigenbases.
+
+    Note:
+        the buffers of the ``basis_opt`` are re-projected whenever the basis changes. The reprojection logic is not implemented on all modules. Some supported modules are:
+
+        ``Adagrad``, ``Adam``, ``Adan``, ``Lion``, ``MARSCorrection``, ``MSAMMomentum``, ``RMSprop``, ``EMA``, ``HeavyBall``, ``NAG``, ``ClipNormByEMA``, ``ClipValueByEMA``, ``NormalizeByEMA``, ``ClipValueGrowth``, ``CoordinateMomentum``, ``CubicAdam``.
+
+        Additionally most modules with no internal buffers are supported, e.g. ``Cautious``, ``Sign``, ``ClipNorm``, ``Orthogonalize``, etc. However modules that use weight values, such as ``WeightDecay``, can't be supported, as weights can't be projected.
+
+        Also, if you, say, use ``EMA`` on the output of ``Pow(2)``, the exponential average will be reprojected as a gradient and not as squared gradients. Use modules like ``EMASquared``, ``SqrtEMASquared`` to get correct reprojections.
+
+    Args:
+        basis_opt (Chainable): module or modules to run in Shampoo eigenbases.
+        shampoo_beta (float | None, optional):
+            beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
+        precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
+        merge_small (bool, optional): Whether to merge small dims. Defaults to True.
+        max_dim (int, optional): Won't precondition dims larger than this. Defaults to 4096.
+        precondition_1d (bool, optional):
+            Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
+        inner (Chainable | None, optional):
+            output of this module is projected and ``basis_opt`` will run on it, but preconditioners are updated
+            from original gradients.
+
+    Examples:
+        SOAP with MARS and AMSGrad:
+        ```python
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.SOAPBasis([tz.m.MARSCorrection(0.95), tz.m.Adam(0.95, 0.95, amsgrad=True)]),
+            tz.m.LR(1e-3)
+        )
+        ```
+
+        LaProp in Shampoo eigenbases (SOLP):
+        ```python
+
+        # we define LaProp through other modules, moved it out for brevity
+        laprop = (
+            tz.m.RMSprop(0.95),
+            tz.m.Debias(beta1=None, beta2=0.95),
+            tz.m.EMA(0.95),
+            tz.m.Debias(beta1=0.95, beta2=None),
+        )
+
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.SOAPBasis(laprop),
+            tz.m.LR(1e-3)
+        )
+        ```
+
+        Lion in Shampoo eigenbases (works kinda well):
+        ```python
+
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.SOAPBasis(tz.m.Lion()),
+            tz.m.LR(1e-3)
+        )
+        ```
+    """
+    def __init__(
+        self,
+        basis_opt: Chainable,
+        shampoo_beta: float | None = 0.95,
+        precond_freq: int = 10,
+        merge_small: bool = True,
+        max_dim: int = 4096,
+        precondition_1d: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults['self'], defaults["inner"], defaults["basis_opt"]
+
+        super().__init__(defaults)
+        self.set_child("inner", inner)
+        self.set_child("basis_opt", basis_opt)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        state["exp_avg_proj"] = torch.zeros_like(tensor)
+        state["exp_avg_sq_proj"] = torch.zeros_like(tensor)
+
+        if tensor.ndim <= 1 and not setting["precondition_1d"]:
+            state['GG'] = []
+
+        else:
+            max_dim = setting["max_dim"]
+            state['GG'] = [
+                torch.zeros(s, s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
+
+        # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+        if len([i is not None for i in state['GG']]) == 0:
+            state['GG'] = None
+
+        # first covariance accumulation
+        if state['GG'] is not None:
+            update_soap_covariances_(tensor, GGs_=state['GG'], beta=setting["shampoo_beta"])
+
+            # get projection matrix with first gradients with eigh
+            try: state['Q'] = get_orthogonal_matrix(state['GG'])
+            except torch.linalg.LinAlgError as e:
+                warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
+                state["GG"] = None
+
+        state['step'] = 0
+
+
+    # no update to avoid running merge_dims twice
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        # note
+        # do not modify tensors in-place
+        # because they are used to update preconditioner at the end
+
+        steps = [s["step"] for s in states]
+        if any(s == 0 for s in steps):
+            # skip 1st update so to avoid using current gradient in the projection
+            # I scale it instead to avoid issues with further modules
+            for s in states: s["step"] += 1
+            return TensorList(tensors).clamp(-0.1, 0.1)
+            # return TensorList(tensors).zero_()
+
+        merged_updates = [] # for when exp_avg is maintained unprojected
+        merged_grads = [] # this doesn't go into preconditioner
+        projected = []
+
+        # -------------------------------- inner step -------------------------------- #
+        updates = tensors
+        has_inner = "inner" in self.children
+        if has_inner:
+            updates = self.inner_step_tensors("inner", updates, clone=True,
+                                              params=params, grads=grads, loss=loss)
+
+        # ---------------------------------- project --------------------------------- #
+        for grad, update, state, setting in zip(tensors, updates, states, settings):
+            if setting["merge_small"]:
+                update, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(update, setting["max_dim"])
+                if has_inner: # grad is a different tensor, merge it too
+                    grad, _, _ = _merge_small_dims(grad, setting["max_dim"])
+                else: # in this case update is still just grad
+                    grad = update
+
+            merged_updates.append(update)
+            merged_grads.append(grad)
+
+            if state['GG'] is not None:
+                update = project(update, state['Q'])
+
+            projected.append(update)
+
+
+        # ------------------------ run opt in projected space ----------------------- #
+        dirs_proj = self.inner_step_tensors("basis_opt", tensors=projected, clone=True, grads=projected)
+
+        # ------------------------------- project back ------------------------------- #
+        dirs: list[torch.Tensor] = []
+        for dir, state, setting in zip(dirs_proj, states, settings):
+            if state['GG'] is not None:
+                dir = project_back(dir, state['Q'])
+
+            if setting["merge_small"]:
+                dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])
+
+            dirs.append(dir)
+
+        # -------------------------- update preconditioners -------------------------- #
+        # Update is done after the gradient step to avoid using current gradients in the projection.
+
+        grad_buffs = self.get_child_projected_buffers("basis_opt", "grad")
+        grad_sq_buffs = self.get_child_projected_buffers("basis_opt", ["grad_sq", "grad_cu"])
+
+        for i, (grad, state, setting) in enumerate(zip(merged_grads, states, settings)):
+            if state['GG'] is not None:
+
+                # lerp covariances
+                update_soap_covariances_(grad, state['GG'], beta=setting["shampoo_beta"])
+
+                # (state['step'] - 1) since we start updating on 2nd step
+                if (state['step'] - 1) % setting['precond_freq'] == 0:
+                    g_buffs = [b[i] for b in grad_buffs]
+                    g_sq_buffs = [b[i] for b in grad_sq_buffs]
+
+                    # unproject grad buffers before updating
+                    g_buffs_unproj = [project_back(buff, state["Q"]) for buff in g_buffs]
+
+                    # update projection matrix and exp_avg_sq_proj
+                    try:
+                        state['Q'], g_sq_buffs_new = get_orthogonal_matrix_QR(
+                            g_sq_buffs, state['GG'], state['Q'])
+
+                        for b_old, b_new in zip(g_sq_buffs, g_sq_buffs_new):
+                            set_storage_(b_old, b_new)
+
+                        # re-project grad buffers
+                        for b_proj, b_unproj in zip(g_buffs, g_buffs_unproj):
+                            set_storage_(b_proj, project(b_unproj, state["Q"]))
+
+                    except torch.linalg.LinAlgError:
+                        pass
+
+            state["step"] += 1
+
+        return dirs
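To make the mechanics of the new module concrete, here is a minimal, self-contained sketch (independent of torchzero's own `project`/`project_back`/`update_soap_covariances_` helpers, which additionally handle merged dims and skipped factors) of what "running an optimizer in Shampoo eigenbases" means for a single 2-D parameter:

```python
import torch

torch.manual_seed(0)
G = torch.randn(6, 4)                  # gradient of a 6x4 parameter

# Shampoo-style covariance factors, one per dimension
GG = [G @ G.T, G.T @ G]
Q = [torch.linalg.eigh(M).eigenvectors for M in GG]   # per-dimension eigenbases

g_proj = Q[0].T @ G @ Q[1]             # project the gradient into the eigenbases
d_proj = g_proj.sign()                 # run any elementwise rule there (e.g. a Lion-like sign step)
d = Q[0] @ d_proj @ Q[1].T             # project the direction back to parameter space

print(d.shape)                         # torch.Size([6, 4])
```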
torchzero/modules/clipping/ema_clipping.py
CHANGED
@@ -27,60 +27,67 @@ class ClipNormByEMA(TensorTransform):
         self,
         beta=0.99,
         ord: Metrics = 2,
-        eps=1e-6,
         tensorwise:bool=True,
         max_ema_growth: float | None = 1.5,
-
+        init: float = 0.0,
+        min_norm: float = 1e-6,
+
         inner: Chainable | None = None,
     ):
-        defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise,
+        defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, init=init, min_norm=min_norm, max_ema_growth=max_ema_growth)
         super().__init__(defaults, inner=inner)
+        self.add_projected_keys("grad", "exp_avg")
 
     @torch.no_grad
     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-
+        eps = torch.finfo(tensors[0].dtype).tiny * 2
+        ord, tensorwise, init, max_ema_growth = itemgetter('ord', 'tensorwise', 'init', 'max_ema_growth')(settings[0])
 
-        beta,
+        beta, min_norm = unpack_dicts(settings, 'beta', 'min_norm', cls=NumberList)
 
-
+        exp_avg = unpack_states(states, tensors, 'exp_avg', init = lambda x: torch.full_like(x, init), cls=TensorList)
 
-
+        exp_avg.lerp_(tensors, 1-beta)
 
+        # ----------------------------- tensorwise update ---------------------------- #
        if tensorwise:
-
+            tensors_norm = tensors.norm(ord)
+            ema_norm = exp_avg.metric(ord)
 
             # clip ema norm growth
             if max_ema_growth is not None:
                 prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
-                allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=
+                allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=min_norm)
+
                 ema_denom = (ema_norm / allowed_norm).clip(min=1)
-
+                exp_avg.div_(ema_denom)
                 ema_norm.div_(ema_denom)
+
                 prev_ema_norm.set_(ema_norm)
 
-            tensors_norm = tensors.norm(ord)
-            denom = tensors_norm / ema_norm.clip(min=eps)
-            if self.NORMALIZE: denom.clip_(min=eps)
-            else: denom.clip_(min=1)
 
+        # ------------------------------- global update ------------------------------ #
         else:
-
+            tensors_norm = tensors.global_metric(ord)
+            ema_norm = exp_avg.global_metric(ord)
 
             # clip ema norm growth
             if max_ema_growth is not None:
                 prev_ema_norm = self.global_state.setdefault('prev_ema_norm', ema_norm)
-                allowed_norm = prev_ema_norm * max_ema_growth
+                allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=min_norm[0])
+
                 if ema_norm > allowed_norm:
-
+                    exp_avg.div_(ema_norm / allowed_norm)
                     ema_norm = allowed_norm
+
                 prev_ema_norm.set_(ema_norm)
 
-            tensors_norm = tensors.global_metric(ord)
-            denom = tensors_norm / ema_norm.clip(min=eps[0])
-            if self.NORMALIZE: denom.clip_(min=eps[0])
-            else: denom.clip_(min=1)
 
+        # ------------------- compute denominator to clip/normalize ------------------ #
+        denom = tensors_norm / ema_norm.clip(min=eps)
+        if self.NORMALIZE: denom.clip_(min=eps)
+        else: denom.clip_(min=1)
         self.global_state['denom'] = denom
 
     @torch.no_grad
@@ -121,7 +128,7 @@ class ClipValueByEMA(TensorTransform):
     def __init__(
         self,
         beta=0.99,
-        init:
+        init: float = 0,
 
         inner: Chainable | None = None,
         exp_avg_tfm:Chainable | None=None,
@@ -130,12 +137,10 @@ class ClipValueByEMA(TensorTransform):
         super().__init__(defaults, inner=inner)
 
         self.set_child('exp_avg', exp_avg_tfm)
+        self.add_projected_keys("grad", "exp_avg")
 
     def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
-
-            state["exp_avg"] = torch.zeros_like(tensor)
-        else:
-            state["exp_avg"] = tensor.abs()
+        state["exp_avg"] = tensor.abs() * setting["init"]
 
     @torch.no_grad
     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
@@ -153,4 +158,4 @@ class ClipValueByEMA(TensorTransform):
             self.inner_step_tensors("exp_avg", exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
 
         tensors.clip_(-exp_avg, exp_avg)
-        return tensors
+        return tensors
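The refactored `ClipNormByEMA` now derives `eps` from the tensor dtype and computes one shared denominator for both the tensorwise and global branches. As a rough standalone illustration (not torchzero's API) of what that denominator does, assuming `NORMALIZE` is what distinguishes the normalize variant from the clip variant:

```python
import torch

def clip_or_normalize(update: torch.Tensor, ema_norm: float, normalize: bool) -> torch.Tensor:
    eps = torch.finfo(update.dtype).tiny * 2
    denom = update.norm() / max(ema_norm, eps)
    if normalize:
        denom = denom.clamp(min=eps)   # always rescale the update to the EMA norm
    else:
        denom = denom.clamp(min=1)     # only shrink updates whose norm exceeds the EMA norm
    return update / denom

u = torch.randn(10) * 3
print(clip_or_normalize(u, ema_norm=0.5, normalize=False).norm())  # <= 0.5
print(clip_or_normalize(u, ema_norm=0.5, normalize=True).norm())   # ~= 0.5
```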
torchzero/modules/clipping/growth_clipping.py
CHANGED
@@ -30,6 +30,7 @@ class ClipValueGrowth(TensorTransform):
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
         super().__init__(defaults)
+        self.add_projected_keys("grad", "prev")
 
 
     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
torchzero/modules/experimental/__init__.py
CHANGED
@@ -1,13 +1,9 @@
-"""Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested
-from .adanystrom import AdaNystrom
-from .common_directions_whiten import CommonDirectionsWhiten
+"""Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested."""
 from .coordinate_momentum import CoordinateMomentum
 from .cubic_adam import CubicAdam, SubspaceCubicAdam
 from .curveball import CurveBall
-from .eigen_sr1 import EigenSR1
 
 # from dct import DCTProjection
-from .eigengrad import Eigengrad
 from .fft import FFTProjection
 from .gradmin import GradMin
 from .higher_order_newton import HigherOrderNewton
@@ -16,5 +12,4 @@ from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
 from .reduce_outward_lr import ReduceOutwardLR
 from .scipy_newton_cg import ScipyNewtonCG
-from .spsa1 import SPSA1
 from .structural_projections import BlockPartition, TensorizeProjection
torchzero/modules/experimental/coordinate_momentum.py
CHANGED
@@ -29,6 +29,8 @@ class CoordinateMomentum(TensorTransform):
         defaults = dict(p=p)
         super().__init__(defaults)
 
+        self.add_projected_keys("grad", "velocity")
+
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         p = NumberList(s['p'] for s in settings)
torchzero/modules/experimental/cubic_adam.py
CHANGED
@@ -88,6 +88,10 @@ class CubicAdam(TensorTransform):
         defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha,mode=mode)
         super().__init__(defaults)
 
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq")
+        self.add_projected_keys("grad_cu", "exp_avg_cu")
+
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
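Most of the remaining changes in this release are one-liners like the three above: each module declares which of its state buffers live in gradient space (key ``"grad"``) or in squared/cubed-gradient space (``"grad_sq"``, ``"grad_cu"``), so that a basis module such as ``SOAPBasis`` knows how to re-project those buffers when its eigenbasis is refreshed. Below is a hedged sketch of the pattern for a hypothetical momentum-style module; the class and buffer names are invented, and the base-class contract is my reading of the diff, not documented API.

```python
import torch
from torchzero.core import TensorTransform

class HeavyBallLike(TensorTransform):
    """Hypothetical module: keeps one buffer that lives in the same space as the gradient."""
    def __init__(self, momentum: float = 0.9):
        super().__init__(dict(momentum=momentum))
        # tell basis modules that "velocity" should be re-projected like a gradient
        self.add_projected_keys("grad", "velocity")

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        if "velocity" not in state:
            state["velocity"] = torch.zeros_like(tensor)
        velocity = state["velocity"]
        velocity.mul_(setting["momentum"]).add_(tensor)
        return velocity.clone()
```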
torchzero/modules/grad_approximation/__init__.py
CHANGED
@@ -1,4 +1,5 @@
-from .grad_approximator import GradApproximator, GradTarget
 from .fdm import FDM
-from .rfdm import RandomizedFDM, MeZO, SPSA, RDSA, GaussianSmoothing
 from .forward_gradient import ForwardGradient
+from .grad_approximator import GradApproximator, GradTarget
+from .rfdm import RDSA, SPSA, GaussianSmoothing, MeZO, RandomizedFDM
+from .spsa1 import SPSA1
torchzero/modules/least_squares/gn.py
CHANGED
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 
 from ...core import Chainable, Transform
@@ -129,6 +131,10 @@ class GaussNewton(Transform):
         r = objective.get_loss(backward=False) # n_residuals
         assert isinstance(r, torch.Tensor)
 
+        if r.numel() == 1:
+            r = r.view(1,1)
+            warnings.warn("Gauss-newton got a single residual. Make sure objective function returns a vector of residuals.")
+
         # set sum of squares scalar loss and it's gradient to objective
         objective.loss = r.pow(2).sum()
 
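The new guard reshapes a scalar residual to ``(1, 1)`` and warns, because Gauss-Newton is only meaningful for a residual *vector*: the scalar loss is ``r.pow(2).sum()`` and the step comes from a linearized least-squares solve. A minimal, self-contained sketch of that step (an illustration, not torchzero's implementation):

```python
import torch

def gauss_newton_step(residual_fn, x: torch.Tensor) -> torch.Tensor:
    r = residual_fn(x)                                       # shape (n_residuals,)
    J = torch.autograd.functional.jacobian(residual_fn, x)   # shape (n_residuals, n_params)
    # solve the linearized least-squares problem min ||J p + r||, i.e. (J^T J) p = -J^T r
    return torch.linalg.lstsq(J, -r.unsqueeze(-1)).solution.squeeze(-1)

# fit y = a*x + b to three points: residuals are predictions minus targets
xs = torch.tensor([0.0, 1.0, 2.0]); ys = torch.tensor([1.0, 3.0, 5.0])
residual_fn = lambda p: p[0] * xs + p[1] - ys
p = torch.zeros(2)
p = p + gauss_newton_step(residual_fn, p)
print(p)  # ~tensor([2., 1.]) after one step, since the model is linear in p
```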
torchzero/modules/misc/misc.py
CHANGED
@@ -25,6 +25,7 @@ class Previous(TensorTransform):
         defaults = dict(n=n)
         super().__init__(defaults=defaults)
 
+        self.add_projected_keys("grad", "history")
 
     @torch.no_grad
     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
@@ -42,6 +43,7 @@ class LastDifference(TensorTransform):
     """Outputs difference between past two updates."""
     def __init__(self,):
         super().__init__()
+        self.add_projected_keys("grad", "prev_tensors")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -54,6 +56,7 @@ class LastGradDifference(Module):
     """Outputs difference between past two gradients."""
     def __init__(self):
         super().__init__()
+        self.add_projected_keys("grad", "prev_grad")
 
     @torch.no_grad
     def apply(self, objective):
@@ -84,6 +87,7 @@ class LastProduct(TensorTransform):
     """Outputs difference between past two updates."""
     def __init__(self):
         super().__init__()
+        self.add_projected_keys("grad", "prev")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -97,6 +101,7 @@ class LastRatio(TensorTransform):
     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur'):
         defaults = dict(numerator=numerator)
         super().__init__(defaults)
+        self.add_projected_keys("grad", "prev")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -112,6 +117,7 @@ class LastAbsoluteRatio(TensorTransform):
     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8):
         defaults = dict(numerator=numerator, eps=eps)
         super().__init__(defaults)
+        self.add_projected_keys("grad", "prev")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
torchzero/modules/momentum/averaging.py
CHANGED
@@ -20,6 +20,8 @@ class Averaging(TensorTransform):
         defaults = dict(history_size=history_size)
         super().__init__(defaults=defaults)
 
+        self.add_projected_keys("grad", "history", "average")
+
     @torch.no_grad
     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         history_size = setting['history_size']
@@ -45,6 +47,8 @@ class WeightedAveraging(TensorTransform):
         defaults = dict(weights = tolist(weights))
         super().__init__(defaults=defaults)
 
+        self.add_projected_keys("grad", "history")
+
     @torch.no_grad
     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         weights = setting['weights']
@@ -79,6 +83,8 @@ class MedianAveraging(TensorTransform):
         defaults = dict(history_size = history_size)
         super().__init__(defaults=defaults)
 
+        self.add_projected_keys("grad", "history")
+
     @torch.no_grad
     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         history_size = setting['history_size']
torchzero/modules/momentum/momentum.py
CHANGED
@@ -24,6 +24,8 @@ class EMA(TensorTransform):
         defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
         super().__init__(defaults, uses_grad=False)
 
+        self.add_projected_keys("grad", "exp_avg")
+
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
@@ -88,6 +90,8 @@ class NAG(TensorTransform):
         defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
         super().__init__(defaults, uses_grad=False)
 
+        self.add_projected_keys("grad", "velocity")
+
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
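These ``"grad"``-type registrations are what ``SOAPBasis`` consumes when it refreshes its eigenbasis: the buffer is rotated back to parameter space with the old ``Q`` and then rotated into the new ``Q`` (see the ``project_back``/``project`` calls in ``soap_basis.py``). Written out for a single 2-D buffer with dense orthogonal factors, a simplified sketch of that re-projection (the real helpers also handle merged dims, ``None`` factors and in-place storage swaps):

```python
import torch

def reproject(buf_proj: torch.Tensor, Q_old: list[torch.Tensor], Q_new: list[torch.Tensor]) -> torch.Tensor:
    # 1. rotate the buffer back to parameter space with the old basis
    buf = Q_old[0] @ buf_proj @ Q_old[1].T
    # 2. rotate it into the refreshed basis
    return Q_new[0].T @ buf @ Q_new[1]

Q_old = [torch.linalg.qr(torch.randn(5, 5)).Q, torch.linalg.qr(torch.randn(3, 3)).Q]
Q_new = [torch.linalg.qr(torch.randn(5, 5)).Q, torch.linalg.qr(torch.randn(3, 3)).Q]
exp_avg_proj = torch.randn(5, 3)
print(reproject(exp_avg_proj, Q_old, Q_new).shape)  # torch.Size([5, 3])
```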
torchzero/modules/ops/accumulate.py
CHANGED
@@ -13,6 +13,7 @@ class AccumulateSum(TensorTransform):
     def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
         super().__init__(defaults)
+        self.add_projected_keys("grad", "sum")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -30,6 +31,7 @@ class AccumulateMean(TensorTransform):
     def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
         super().__init__(defaults)
+        self.add_projected_keys("grad", "mean")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -65,6 +67,7 @@ class AccumulateMaximum(TensorTransform):
     def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
         super().__init__(defaults)
+        self.add_projected_keys("grad", "maximum")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -82,6 +85,7 @@ class AccumulateMinimum(TensorTransform):
     def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
         super().__init__(defaults)
+        self.add_projected_keys("grad", "minimum")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
torchzero/modules/ops/higher_level.py
CHANGED
@@ -30,6 +30,7 @@ class EMASquared(TensorTransform):
     def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
         super().__init__(defaults)
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -57,7 +58,7 @@ class SqrtEMASquared(TensorTransform):
     def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
         super().__init__(defaults)
-
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -141,6 +142,8 @@ class CenteredEMASquared(TensorTransform):
     def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
         super().__init__(defaults, uses_grad=False)
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -175,6 +178,8 @@ class CenteredSqrtEMASquared(TensorTransform):
     def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
         super().__init__(defaults, uses_grad=False)
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
torchzero/modules/second_order/inm.py
CHANGED
@@ -35,6 +35,8 @@ class ImprovedNewton(Transform):
         self,
         damping: float = 0,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigv_tol: float | None = None,
+        truncate: int | None = None,
         update_freq: int = 1,
         precompute_inverse: bool | None = None,
         use_lstsq: bool = False,
@@ -89,6 +91,8 @@ class ImprovedNewton(Transform):
             state = self.global_state,
             damping = fs["damping"],
             eigval_fn = fs["eigval_fn"],
+            eigv_tol = fs["eigv_tol"],
+            truncate = fs["truncate"],
             precompute_inverse = precompute_inverse,
             use_lstsq = fs["use_lstsq"]
         )
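``ImprovedNewton`` now forwards two extra knobs, ``eigv_tol`` and ``truncate``, to its solver. Their exact semantics live in that solver, but the names suggest an eigendecomposition-based solve that discards small or trailing eigenvalues; the sketch below is an assumption for illustration, not torchzero's implementation.

```python
import torch

def truncated_newton_solve(H: torch.Tensor, g: torch.Tensor,
                           eigv_tol: float | None = None,
                           truncate: int | None = None) -> torch.Tensor:
    L, Q = torch.linalg.eigh(H)                        # eigenvalues in ascending order
    keep = torch.ones_like(L, dtype=torch.bool)
    if eigv_tol is not None:
        keep &= L.abs() > eigv_tol * L.abs().max()     # drop near-zero curvature directions
    if truncate is not None:
        keep &= torch.arange(len(L)) >= len(L) - truncate  # keep only the `truncate` largest
    L, Q = L[keep], Q[:, keep]
    return Q @ ((Q.T @ g) / L)                         # pseudo-inverse of H applied to g

H = torch.tensor([[3.0, 0.0], [0.0, 1e-9]])
g = torch.tensor([3.0, 1.0])
print(truncated_newton_solve(H, g, eigv_tol=1e-6))     # ~tensor([1., 0.]); the tiny mode is ignored
```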