torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/core/module.py
CHANGED
@@ -1,510 +1,629 @@
(The 510 removed lines of the previous implementation are truncated in the diff view and cannot be recovered; the 629 added lines, i.e. the new contents of torchzero/core/module.py, follow.)
import warnings
from abc import ABC, abstractmethod
from collections import ChainMap, defaultdict
from collections.abc import Callable, Iterable, MutableMapping, Sequence
from operator import itemgetter
from typing import Any, final, overload

import torch

from ..utils import (
    Init,
    ListLike,
    Params,
    _make_param_groups,
    get_state_vals,
)
from ..utils.python_tools import flatten


def _closure_backward(closure, params, retain_graph, create_graph):
    with torch.enable_grad():
        if not (retain_graph or create_graph):
            return closure()

        for p in params: p.grad = None
        loss = closure(False)
        grad = torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph)
        for p,g in zip(params,grad): p.grad = g
        return loss

# region Vars
# ----------------------------------- vars ----------------------------------- #
class Vars:
    """
    Holds the state and context passed between optimizer modules during a step.

    This class acts as a mutable container for information relevant to the current
    optimization step, such as parameters, gradients, loss, and the computed update.
    Modules read from and write to this object to coordinate their actions.
    """
    def __init__(
        self,
        params: list[torch.Tensor],
        closure: Callable | None,
        model: torch.nn.Module | None,
        current_step: int,
    ):
        self.params: list[torch.Tensor] = params
        """List of all parameters with requires_grad = True."""

        self.closure = closure
        """A closure that reevaluates the model and returns the loss, None if it wasn't specified"""

        self.model = model
        """torch.nn.Module object of the model, None if it wasn't specified."""

        self.current_step: int = current_step
        """global current step, starts at 0"""

        self.update: list[torch.Tensor] | None = None
        """
        current update, at the end this is subtracted from model parameters unless it is None.

        If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
        """

        self.grad: list[torch.Tensor] | None = None
        """gradient with current parameters. If closure is not None, this is set to None and can be calculated if needed."""

        self.loss: torch.Tensor | Any | None = None
        """loss with current parameters."""

        self.loss_approx: torch.Tensor | Any | None = None
        """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
        whereas some other modules require loss strictly at current point."""

        self.post_step_hooks: list[Callable[[Modular, Vars]]] = []
        """list of functions to be called after optimizer step.
        The signature is:

        .. code:: py

            def hook(optimizer: Modular, vars: Vars): ...

        """

        self.is_last: bool = False
        """
        Indicates that current module is either last or next-to-last before a learning rate module.
        This is always False if current module has children or is a child.
        """

        self.nested_is_last: bool = False
        """
        Indicates that current module is either last or next-to-last before a learning rate module, for modules
        that have children.
        """

        self.last_module_lrs: list[float] | None = None
        """
        List of per-parameter learning rates if current module is next-to-last before a
        learning rate module, otherwise this is set to None. Ignore this unless you are manually applying
        update to parameters.
        """

        self.stop: bool = False
        """if True, all following modules will be skipped."""

        self.skip_update: bool = False
        """if True, the parameters will not be updated"""

    def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`vars.loss`.
        Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""

        if self.loss is None:
            if self.closure is None: raise RuntimeError("closure is None")
            if backward:
                with torch.enable_grad():
                    self.loss = self.loss_approx = _closure_backward(
                        closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
                    )

                # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
                # it is technically a more correct approach for when some parameters conditionally receive gradients
                # and in this case it shouldn't be slower.
                self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
            else:
                self.loss = self.loss_approx = self.closure(False)

        # if self.loss was not None, above branch wasn't executed because loss has already been evaluated, but without backward since self.grad is None.
        # and now it is requested to be evaluated with backward.
        if backward and self.grad is None:
            warnings.warn('get_loss was called with backward=False, and then with backward=True so it had to be re-evaluated, so the closure was evaluated twice where it could have been evaluated once.')
            if self.closure is None: raise RuntimeError("closure is None")

            with torch.enable_grad():
                self.loss = self.loss_approx = _closure_backward(
                    closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
                )
            self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
        return self.loss # type:ignore

    def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
        """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
        :code:`vars.grad` and potentially :code:`vars.loss`. Do not call this at perturbed parameters."""
        if self.grad is None:
            if self.closure is None: raise RuntimeError("closure is None")
            self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad

        assert self.grad is not None
        return self.grad

    def get_update(self) -> list[torch.Tensor]:
        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`vars.update`.
        Computing the gradients may assign :code:`vars.grad` and :code:`vars.loss` if they haven't been computed.
        Do not call this at perturbed parameters."""
        if self.update is None: self.update = [g.clone() for g in self.get_grad()]
        return self.update

    def clone(self, clone_update: bool):
        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
        copy = Vars(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)

        if clone_update and self.update is not None:
            copy.update = [u.clone() for u in self.update]
        else:
            copy.update = self.update

        copy.grad = self.grad
        copy.loss = self.loss
        copy.loss_approx = self.loss_approx
        copy.post_step_hooks = self.post_step_hooks
        copy.stop = self.stop
        copy.skip_update = self.skip_update

        return copy

    def update_attrs_from_clone_(self, vars: "Vars"):
        """Updates attributes of this `Vars` instance from a cloned instance.
        Typically called after a child module has processed a cloned `Vars`
        object. This propagates any newly computed loss or gradient values
        from the child's context back to the parent `Vars` if the parent
        didn't have them computed already.
        """
        if self.loss is None: self.loss = vars.loss
        if self.loss_approx is None: self.loss_approx = vars.loss_approx
        if self.grad is None: self.grad = vars.grad

    def zero_grad(self, set_to_none=True):
        if set_to_none:
            for p in self.params: p.grad = None
        else:
            grads = [p.grad for p in self.params if p.grad is not None]
            if len(grads) != 0: torch._foreach_zero_(grads)

# endregion

# region Module
# ---------------------------------- module ---------------------------------- #
class Module(ABC):
    """Abstract base class for optimizer modules.

    Modules represent distinct steps or transformations within the optimization
    process (e.g., momentum, line search, gradient accumulation).

    A module does not store parameters, but it maintains per-parameter state and per-parameter settings
    where tensors are used as keys (same as torch.optim.Optimizer state).

    Args:
        defaults (dict[str, Any] | None):
            a dict containing default values of optimization options (used when a parameter group doesn't specify them).
    """
    def __init__(self, defaults: dict[str, Any] | None = None):
        if defaults is None: defaults = {}
        self.defaults: dict[str, Any] = defaults

        # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
        # 0 - this module specific per-parameter setting overrides set via `set_param_groups` - highest priority
        # 1 - global per-parameter setting overrides in param_groups passed to Modular - medium priority
        # 2 - `defaults` - lowest priority
        self.settings: defaultdict[torch.Tensor, ChainMap[str, Any]] = defaultdict(lambda: ChainMap({}, {}, self.defaults))
        """per-parameter settings."""

        self.state: defaultdict[torch.Tensor, dict[str, Any]] = defaultdict(dict)
        """Per-parameter state (e.g., momentum buffers)."""

        self.global_state: dict[str, Any] = {}
        """Global state for things that are not per-parameter."""

        self.children: dict[str, Module] = {}
        """A dictionary of child modules."""

        self._overridden_keys = set()
        """tracks keys overridden with `set_param_groups`, only used to not give a warning"""


    def set_param_groups(self, param_groups: Params):
        """Set custom parameter groups with per-parameter settings that this module will use."""
        param_groups = _make_param_groups(param_groups, differentiable=False)
        for group in param_groups:
            settings = group.copy()
            params = settings.pop('params')
            if not settings: continue
            self._overridden_keys.update(*settings.keys())

            for param in params:
                self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
        return self

    def set_child(self, key: str, module: "Module | Sequence[Module]"):
        self.children[key] = maybe_chain(module)

    def set_children_sequence(self, modules: "Iterable[Module | Sequence[Module]]", prefix = 'module_'):
        modules = list(modules)
        for i, m in enumerate(modules):
            self.set_child(f'{prefix}{i}', maybe_chain(m))

    def get_children_sequence(self, prefix = 'module_'):
        return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]

    def __repr__(self):
        s = self.__class__.__name__
        if self.children:
            s = f'{s}('
            for k,v in self.children.items():
                s = f'{s}{k}={v}, '
            s = f'{s[:-2]})'
        return s

    @overload
    def get_settings(self, key: str, *,
                     params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike: ...
    @overload
    def get_settings(self, key: list[str] | tuple[str,...], *,
                     params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...
    @overload
    def get_settings(self, key: str, key2: str, *keys: str,
                     params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...

    def get_settings(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
                     params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike | list[ListLike]:
        # if isinstance(params, Vars): params = params.params
        return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]


    @overload
    def get_state(self, key: str, *,
                  params: Sequence[torch.Tensor], must_exist: bool = False, init: Init = torch.zeros_like,
                  cls: type[ListLike] = list) -> ListLike: ...
    @overload
    def get_state(self, key: list[str] | tuple[str,...], *,
                  params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                  cls: type[ListLike] = list) -> list[ListLike]: ...
    @overload
    def get_state(self, key: str, key2: str, *keys: str,
                  params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                  cls: type[ListLike] = list) -> list[ListLike]: ...

    def get_state(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
                  params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                  cls: type[ListLike] = list) -> ListLike | list[ListLike]:
        """Returns values of per-parameter state for a given key.
        If key doesn't exist, create it with inits.

        This functions like `operator.itemgetter`, returning a single value if called with a single key,
        or a tuple if called with multiple keys.

        If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.

        .. code:: py

            exp_avg = self.state_vals("exp_avg")
            # returns cls (by default TensorList)

            exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
            # returns list of cls

            exp_avg = self.state_vals(["exp_avg"])
            # always returns a list of cls, even if got a single key


        Args:
            *keys (str):
                the keys to look for in each parameters state.
                if a single key is specified, this returns a single value or cls,
                otherwise this returns a list of values or cls per each key.
            params (Iterable[torch.Tensor]): parameters to return the states for.
            must_exist (bool, optional):
                If a key doesn't exist in state, if True, raises a KeyError, if False, creates the value
                using `init` argument (default = False).
            init (Init | Sequence[Init], optional):
                how to initialize a key if it doesn't exist.

                can be
                - Callable like torch.zeros_like
                - string - "param" or "grad" to use cloned params or cloned grads.
                - anything else other than list/tuples will be used as-is, tensors will be cloned.
                - list/tuple of values per each parameter, only if got a single key.
                - list/tuple of values per each key, only if got multiple keys.

                if multiple `keys` are specified, inits is per-key!

                Defaults to torch.zeros_like.
            cls (type[ListLike], optional):
                MutableSequence class to return, this only has effect when state_keys is a list/tuple. Defaults to list.

        Returns:
            - if state_keys has a single key and keys has a single key, return a single value.
            - if state_keys has a single key and keys has multiple keys, return a list of values.
            - if state_keys has multiple keys and keys has a single key, return cls.
            - if state_keys has multiple keys and keys has multiple keys, return list of cls.
        """
        # if isinstance(params, Vars): params = params.params
        return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]

    # def first_setting(self, *keys:str, params:Sequence[torch.Tensor]):
    #     # if isinstance(params, Vars): params = params.params
    #     return itemgetter(*keys)(self.settings[params[0]])

    def state_dict(self):
        """state dict"""
        packed_state = {id(k):v for k,v in self.state.items()}
        packed_settings = {id(k):v for k,v in self.settings.items()}

        state_dict = {
            "state": packed_state,
            "settings":
                {
                    "local": {k:v.maps[0] for k,v in packed_settings.items()},
                    "global": {k:v.maps[1] for k,v in packed_settings.items()},
                    "defaults": {k:v.maps[2] for k,v in packed_settings.items()},
                },
            "global_state": self.global_state,
            "extra": self._extra_pack(),
            "children": {k: v.state_dict() for k, v in self.children.items()}
        }
        return state_dict

    def load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
        # load state
        state = state_dict['state']
        self.state.clear()
        self.state.update({id_to_tensor[k]:v for k,v in state.items()})

        # load settings
        settings = state_dict['settings']
        self.settings.clear()
        for k, v in settings['local'].items(): self.settings[id_to_tensor[k]].maps[0].update(v)
        for k, v in settings['global'].items(): self.settings[id_to_tensor[k]].maps[1].update(v)
        for k, v in settings['defaults'].items(): self.settings[id_to_tensor[k]].maps[2].update(v)

        # load global state
        self.global_state.clear()
        self.global_state.update(state_dict['global_state'])

        # children
        for k, v in state_dict['children']:
            if k in self.children: self.children[k].load_state_dict(v, id_to_tensor)
            else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')

        # extra info
        self._extra_unpack(state_dict['extra'])

    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
    @abstractmethod
    def step(self, vars: Vars) -> Vars:
        """performs a step, returns new vars but may update them in-place."""

    def reset(self):
        """Resets the internal state of the module (e.g. momentum)."""
        # no complex logic is allowed there because this is overridden by many modules
        # where super().reset() shouldn't be called
        self.state.clear()
        self.global_state.clear()

    def _extra_pack(self):
        return {}

    def _extra_unpack(self, x):
        pass

# endregion

Chainable = Module | Sequence[Module]


def unroll_modules(*modules: Chainable) -> list[Module]:
    unrolled = []

    for m in modules:
        if isinstance(m, Module):
            unrolled.append(m)
            unrolled.extend(unroll_modules(list(m.children.values())))
        else:
            unrolled.extend(unroll_modules(*m))

    return unrolled


# region Modular
# ---------------------------------- Modular --------------------------------- #
# have to inherit from torch.optim.Optimizer to support lr schedulers
# although Accelerate doesn't work due to converting param_groups to a dict
class Modular(torch.optim.Optimizer):
    """Chains multiple modules into an optimizer.

    Args:
        params (Params | torch.nn.Module): An iterable of parameters to optimize
            (typically `model.parameters()`), an iterable of parameter group dicts,
            or a `torch.nn.Module` instance.
        *modules (Module): A sequence of `Module` instances that define the
            optimization algorithm steps.
    """
    # this is specifically for lr schedulers
    param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]

    def __init__(self, params: Params | torch.nn.Module, *modules: Module):
        self.model: torch.nn.Module | None = None
        """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
        if isinstance(params, torch.nn.Module):
            self.model = params
            params = params.parameters()

        self.modules = modules
        """Top-level modules provided during initialization."""

        self.unrolled_modules = unroll_modules(self.modules)
        """A flattened list of all modules including all children."""

        param_groups = _make_param_groups(params, differentiable=False)
        self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}

        # make sure there is no more than a single learning rate module
        lr_modules = [m for m in self.unrolled_modules if 'lr' in m.defaults]
        if len(lr_modules) > 1:
            warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to compounding of learning rate multiplication with per-parameter learning rates and schedulers.')

        # iterate over all per-parameter settings overrides and check if they are applied at most once
        for group in param_groups:
            for k in group:
                if k in ('params', 'lr'): continue
                modules_with_k = [m for m in self.unrolled_modules if k in m.defaults and k not in m._overridden_keys]
                if len(modules_with_k) > 1:
                    warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')

        # defaults for schedulers
        defaults = {}
        for m in self.unrolled_modules: defaults.update(m.defaults)
        super().__init__(param_groups, defaults=defaults)

        # note - this is what super init does:

        # self.defaults = defaults
        # for param_group in param_groups:
        #     self.add_param_group(param_group)

        self.current_step = 0
        """The global step counter for the optimizer."""

    def add_param_group(self, param_group: dict[str, Any]):
        proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
        self.param_groups.append(ChainMap(proc_param_group, self.defaults))

        for p in proc_param_group['params']:
            # updates global per-parameter setting overrides (medium priority)
            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.unrolled_modules]

    def state_dict(self):
        all_params = [p for g in self.param_groups for p in g['params']]
        id_to_idx = {id(p): i for i,p in enumerate(all_params)}

        groups = []
        for g in self.param_groups:
            g = g.copy()
            g['params'] = [id_to_idx[id(p)] for p in g['params']]
            groups.append(g)

        state_dict = {
            "idx_to_id": {v:k for k,v in id_to_idx.items()},
            "params": all_params,
            "groups": groups,
            "defaults": self.defaults,
            "modules": {i: m.state_dict() for i, m in enumerate(self.unrolled_modules)}
        }
        return state_dict

    def load_state_dict(self, state_dict: dict):
        self.defaults.clear()
        self.defaults.update(state_dict['defaults'])

        idx_to_param = dict(enumerate(state_dict['params']))
        groups = []
        for g in state_dict['groups']:
            g = g.copy()
            g['params'] = [idx_to_param[p] for p in g['params']]
            groups.append(g)

        self.param_groups.clear()
        for group in groups:
            self.add_param_group(group)

        id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
        for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
            m.load_state_dict(sd, id_to_tensor)


    def step(self, closure=None): # pyright: ignore[reportIncompatibleMethodOverride]
        # propagate global per-parameter setting overrides
        for g in self.param_groups:
            settings = dict(g.maps[0]) # ignore defaults
            params = settings.pop('params')
            if not settings: continue

            for p in params:
                if not p.requires_grad: continue
                for map in self._per_parameter_global_settings[p]: map.update(settings)

        # create vars
        params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
        vars = Vars(params=params, closure=closure, model=self.model, current_step=self.current_step)

        # if closure is None, assume backward has been called and gather grads
        if closure is None:
            vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

        last_module = self.modules[-1]
        last_lr = last_module.defaults.get('lr', None)
        n_modules = len(self.modules)

        # step
        for i, module in enumerate(self.modules):
            if i!=0: vars = vars.clone(clone_update=False)

            # last module, or next to last module before lr
            if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
                if module.children: vars.nested_is_last = True
                else: vars.is_last = True
                if last_lr is not None: vars.last_module_lrs = last_module.get_settings('lr', params=vars.params)

            vars = module.step(vars)
            if vars.stop: break

        # apply update
        if not vars.skip_update:
            with torch.no_grad():
                torch._foreach_sub_(params, vars.get_update())

        for hook in vars.post_step_hooks:
            hook(self, vars)

        self.current_step += 1
        return vars.loss if vars.loss is not None else vars.loss_approx

    def __repr__(self):
        return f'Modular({", ".join(str(m) for m in self.modules)})'
# endregion

# region Chain
# ----------------------------------- Chain ---------------------------------- #
class Chain(Module):
    """Chain of modules, mostly used internally"""
    def __init__(self, *modules: Module | Iterable[Module]):
        super().__init__()
        flat_modules: list[Module] = flatten(modules)
        for i, module in enumerate(flat_modules):
            self.set_child(f'module_{i}', module)

    def step(self, vars):
        for i in range(len(self.children)):
            vars = self.children[f'module_{i}'].step(vars)
            if vars.stop: break
        return vars

    def __repr__(self):
        s = self.__class__.__name__
        if self.children:
            if s == 'Chain': s = 'C' # to shorten it
            s = f'{s}({", ".join(str(m) for m in self.children.values())}'
        return s

def maybe_chain(*modules: Chainable) -> Module:
    """Returns a single module directly if only one is provided, otherwise wraps them in a :code:`Chain`."""
    flat_modules: list[Module] = flatten(modules)
    if len(flat_modules) == 1:
        return flat_modules[0]
    return Chain(*flat_modules)
# endregion
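To make the pieces above concrete, here is a minimal usage sketch. It is not part of the package: the `HeavyBall` module below is hypothetical, built on the `Module`/`Vars` API defined in this file and driven through `Modular`, and it assumes `get_state` and `get_settings` behave as their docstrings describe (per-parameter lists, zero-initialized state) and that the imports resolve from this file's location.

import torch
from torchzero.core.module import Modular, Module, Vars

class HeavyBall(Module):
    """Hypothetical example module: heavy-ball momentum scaled by `lr`."""
    def __init__(self, lr: float = 0.01, momentum: float = 0.9):
        super().__init__(defaults=dict(lr=lr, momentum=momentum))

    def step(self, vars: Vars) -> Vars:
        update = vars.get_update()  # per-parameter update tensors, cloned from grads if unset
        lrs = self.get_settings('lr', params=vars.params)
        momenta = self.get_settings('momentum', params=vars.params)
        velocities = self.get_state('velocity', params=vars.params)  # zeros_like on first use

        new_update = []
        for u, v, lr, m in zip(update, velocities, lrs, momenta):
            v.mul_(m).add_(u)           # v = m * v + u
            new_update.append(v * lr)   # Modular subtracts this from the parameters
        vars.update = new_update
        return vars

model = torch.nn.Linear(4, 1)
opt = Modular(model, HeavyBall(lr=0.01))
inputs = torch.randn(8, 4)

def closure(backward=True):
    # standard torch.optim-style closure; passing backward=False lets modules re-evaluate the loss only
    loss = model(inputs).pow(2).mean()
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)

Because `Modular.step` subtracts `vars.get_update()` from the parameters at the end, a module only needs to transform the update list; it does not write to the parameters directly unless it sets `vars.skip_update` and applies its own change.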