torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/projections/projection.py
@@ -0,0 +1,218 @@
+import math
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Any, Literal
+import warnings
+import torch
+
+from ...core import Chainable, Module, Vars
+from ...utils import vec_to_tensors
+
+
+def _make_projected_closure(closure, vars: Vars, projection: "Projection",
+                            params: list[torch.Tensor], projected_params: list[torch.Tensor]):
+
+    def projected_closure(backward=True):
+        unprojected_params = projection.unproject(projected_params, vars, current='params')
+
+        with torch.no_grad():
+            for p, new_p in zip(params, unprojected_params):
+                p.set_(new_p) # pyright: ignore[reportArgumentType]
+
+        if backward:
+            loss = closure()
+            grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            projected_grads = projection.project(grads, vars, current='grads')
+            for p, g in zip(projected_params, projected_grads):
+                p.grad = g
+
+        else:
+            loss = closure(False)
+
+        return loss
+
+    return projected_closure
+
+
+class Projection(Module, ABC):
+    """
+    Base class for projections.
+    This is an abstract class, to use it, subclass it and override `project` and `unproject`.
+
+    Args:
+        modules (Chainable): modules that will be applied in the projected domain.
+        project_update (bool, optional): whether to project the update. Defaults to True.
+        project_params (bool, optional):
+            whether to project the params. This is necessary for modules that use closure. Defaults to False.
+        project_grad (bool, optional): whether to project the gradients (separately from update). Defaults to False.
+        defaults (dict[str, Any] | None, optional): dictionary with defaults. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        modules: Chainable,
+        project_update=True,
+        project_params=False,
+        project_grad=False,
+        defaults: dict[str, Any] | None = None,
+    ):
+        super().__init__(defaults)
+        self.set_child('modules', modules)
+        self.global_state['current_step'] = 0
+        self._project_update = project_update
+        self._project_params = project_params
+        self._project_grad = project_grad
+        self._projected_params = None
+
+    @abstractmethod
+    def project(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+        """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
+
+    @abstractmethod
+    def unproject(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+        """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
+
+    @torch.no_grad
+    def step(self, vars: Vars):
+        projected_vars = vars.clone(clone_update=False)
+        update_is_grad = False
+
+        # closure will calculate projected update and grad if needed
+        if self._project_params and vars.closure is not None:
+            if self._project_update and vars.update is not None: projected_vars.update = list(self.project(vars.update, vars=vars, current='update'))
+            else:
+                update_is_grad = True
+            if self._project_grad and vars.grad is not None: projected_vars.grad = list(self.project(vars.grad, vars=vars, current='grads'))
+
+        # project update and grad, unprojected attributes are deleted
+        else:
+            if self._project_update:
+                if vars.update is None:
+                    # update is None, meaning it will be set to `grad`.
+                    # we can project grad and use it for update
+                    grad = vars.get_grad()
+                    projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
+                    if self._project_grad: projected_vars.update = [g.clone() for g in projected_vars.grad]
+                    else: projected_vars.update = projected_vars.grad.copy() # don't clone because grad shouldn't be used
+                    del vars.update
+                    update_is_grad = True
+
+                else:
+                    update = vars.get_update()
+                    projected_vars.update = list(self.project(update, vars=vars, current='update'))
+                    del update, vars.update
+
+            if self._project_grad and projected_vars.grad is None:
+                grad = vars.get_grad()
+                projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
+
+        original_params = None
+        if self._project_params:
+            original_params = [p.clone() for p in vars.params]
+            projected_params = self.project(vars.params, vars=vars, current='params')
+
+        else:
+            # make fake params for correct shapes and state storage
+            # they reuse update or grad storage for memory efficiency
+            projected_params = projected_vars.update if projected_vars.update is not None else projected_vars.grad
+            assert projected_params is not None
+
+        if self._projected_params is None:
+            # 1st step - create objects for projected_params. They have to remain the same python objects
+            # to support per-parameter states which are stored by ids.
+            self._projected_params = [p.view_as(p).requires_grad_() for p in projected_params]
+        else:
+            # set storage to new fake params while ID remains the same
+            for empty_p, new_p in zip(self._projected_params, projected_params):
+                empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]
+
+        # project closure
+        if self._project_params:
+            closure = vars.closure; params = vars.params
+            projected_vars.closure = _make_projected_closure(closure, vars=vars, projection=self, params=params,
+                                                             projected_params=self._projected_params)
+
+        else:
+            projected_vars.closure = None
+
+        # step
+        projected_vars.params = self._projected_params
+        projected_vars = self.children['modules'].step(projected_vars)
+
+        # empty fake params storage
+        # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
+        if not self._project_params:
+            for p in self._projected_params:
+                p.set_(torch.empty(0, device=p.device, dtype=p.dtype)) # pyright: ignore[reportArgumentType]
+
+        # unproject
+        unprojected_vars = projected_vars.clone(clone_update=False)
+        unprojected_vars.closure = vars.closure
+        unprojected_vars.params = vars.params
+        if unprojected_vars.grad is None: unprojected_vars.grad = vars.grad
+
+        if self._project_update:
+            assert projected_vars.update is not None
+            unprojected_vars.update = list(self.unproject(projected_vars.update, vars=vars, current='grads' if update_is_grad else 'update'))
+            del projected_vars.update
+
+        # unprojecting grad doesn't make sense?
+        # if self._project_grad:
+        #     assert projected_vars.grad is not None
+        #     unprojected_vars.grad = list(self.unproject(projected_vars.grad, vars=vars))
+
+        del projected_vars
+
+        if original_params is not None:
+            for p, o in zip(unprojected_vars.params, original_params):
+                p.set_(o) # pyright: ignore[reportArgumentType]
+
+        return unprojected_vars
+
+
+
+class FlipConcatProjection(Projection):
+    """
+    for testing
+    """
+
+    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        return vec_to_tensors(vec=tensors[0].flip(0), reference=vars.params)
+
+
+class NoopProjection(Projection):
+    """an example projection which doesn't do anything for testing"""
+
+    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        return tensors
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        return tensors
+
+class MultipyProjection(Projection):
+    """an example projection which multiplies everything by 2"""
+
+    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        return torch._foreach_mul(tensors, 2)
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        return torch._foreach_div(tensors, 2)
+
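For illustration, a minimal sketch of extending the Projection base class above, in the same spirit as the NoopProjection/MultipyProjection examples it ships with; the subclass name, the float16 round-trip, and the import path are assumptions for the example, not part of the package:

    import torch
    # import path follows the new file location in this diff (hypothetical usage)
    from torchzero.modules.projections.projection import Projection

    class HalfPrecisionProjection(Projection):
        """hypothetical example: run the wrapped modules on float16 copies of the tensors"""

        def __init__(self, modules, project_update=True, project_params=False, project_grad=False):
            super().__init__(modules, project_update=project_update,
                             project_params=project_params, project_grad=project_grad)

        @torch.no_grad
        def project(self, tensors, vars, current):
            # cast whatever is passed in (update, grads or params) to half precision
            return [t.half() for t in tensors]

        @torch.no_grad
        def unproject(self, tensors, vars, current):
            # cast back to each parameter's original dtype
            return [t.to(p.dtype) for t, p in zip(tensors, vars.params)]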
torchzero/modules/projections/structural.py
@@ -0,0 +1,151 @@
+import math
+
+import numpy as np
+import torch
+
+from ...core import Chainable
+from ...utils import vec_to_tensors, TensorList
+from ..optimizers.shampoo import _merge_small_dims
+from .projection import Projection
+
+
+class VectorProjection(Projection):
+    """
+    flattens and concatenates all parameters into a vector
+    """
+    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        return vec_to_tensors(vec=tensors[0], reference=vars.params)
+
+
+
+class TensorizeProjection(Projection):
+    """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
+    def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
+        defaults = dict(max_side=max_side)
+        super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        params = vars.params
+        max_side = self.settings[params[0]]['max_side']
+        num_elems = sum(t.numel() for t in tensors)
+
+        if num_elems < max_side:
+            self.global_state['remainder'] = 0
+            # return 1d
+            return [torch.cat([t.view(-1) for t in tensors])]
+
+
+        # determine appropriate shape to reshape into
+        ndims = math.ceil(math.log(num_elems, max_side)) # determine number of dims
+        dim_size = math.ceil(num_elems ** (1/ndims)) # average size of a dim with ndims
+        dims = [dim_size for _ in range(ndims)]
+        required_elems = math.prod(dims)
+
+        # add few extra zeros to vec to match a reshapable size
+        remainder = required_elems-num_elems
+        if remainder > 0: tensors = tensors + [torch.zeros(remainder, dtype=tensors[0].dtype, device=tensors[0].device)]
+        self.global_state['remainder'] = remainder
+
+        # flatten and reshape
+        vec = torch.cat([t.view(-1) for t in tensors])
+        return [vec.view(dims)]
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        remainder = self.global_state['remainder']
+        # warnings.warn(f'{tensors[0].shape = }')
+        vec = tensors[0].view(-1)
+        if remainder > 0: vec = vec[:-remainder]
+        return vec_to_tensors(vec, vars.params)
+
+class BlockPartition(Projection):
+    """splits parameters into blocks (for now flatttens them and chunks)"""
+    def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
+        defaults = dict(max_size=max_size, batched=batched)
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        partitioned = []
+        for p,t in zip(vars.params, tensors):
+            settings = self.settings[p]
+            max_size = settings['max_size']
+            n = t.numel()
+            if n <= max_size:
+                partitioned.append(t)
+                continue
+
+            t_flat = t.view(-1)
+
+            batched = settings['batched']
+            num_chunks = math.ceil(n / max_size)
+
+            if batched:
+                chunks_size = num_chunks * max_size
+                if num_chunks * max_size > n:
+                    t_flat = torch.cat([t_flat, torch.zeros(n-chunks_size, dtype=t_flat.dtype, device=t_flat.device)])
+                partitioned.append(t_flat.view(num_chunks, -1))
+
+            else:
+                partitioned.extend(t_flat.chunk(num_chunks))
+
+        return partitioned
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        ti = iter(tensors)
+        unprojected = []
+        for p in vars.params:
+            settings = self.settings[p]
+            n = p.numel()
+
+            if settings['batched']:
+                unprojected.append(next(ti).view(-1)[:n].view_as(p))
+
+            else:
+                chunks = []
+                t_n = 0
+                while t_n < n:
+                    t = next(ti)
+                    chunks.append(t)
+                    t_n += t.numel()
+
+                assert t_n == n
+                unprojected.append(torch.cat(chunks).view_as(p))
+
+        return unprojected
+
+
+class TensorNormsProjection(Projection):
+    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        orig = self.get_state(f'{current}_orig', params=vars.params)
+        torch._foreach_copy_(orig, tensors)
+
+        norms = torch._foreach_norm(tensors)
+        self.get_state(f'{current}_orig_norms', params=vars.params, init=norms, cls=TensorList).set_(norms)
+
+        return [torch.stack(norms)]
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        orig = self.get_state(f'{current}_orig', params=vars.params)
+        orig_norms = torch.stack(self.get_state(f'{current}_orig_norms', params=vars.params))
+        target_norms = tensors[0]
+
+        orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
+
+        torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
+        return orig
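For reference, the reshape arithmetic in TensorizeProjection.project above works out as follows on made-up sizes (the numbers are only an example): a 1050-element concatenated vector with max_side=16 gets 3 dims of size 11, so it is padded with 281 zeros and viewed as an 11x11x11 tensor, and unproject strips the padding again.

    import math

    num_elems, max_side = 1050, 16                    # example sizes, not from the package
    ndims = math.ceil(math.log(num_elems, max_side))  # ceil(log_16(1050)) = 3
    dim_size = math.ceil(num_elems ** (1 / ndims))    # ceil(1050 ** (1/3)) = 11
    dims = [dim_size] * ndims                         # [11, 11, 11]
    remainder = math.prod(dims) - num_elems           # 1331 - 1050 = 281 padding zeros
    print(ndims, dim_size, dims, remainder)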
torchzero/modules/quasi_newton/__init__.py
@@ -1,4 +1,7 @@
-
-
-
-# from .
+from .cg import PolakRibiere, FletcherReeves, HestenesStiefel, DaiYuan, LiuStorey, ConjugateDescent, HagerZhang, HybridHS_DY
+from .lbfgs import LBFGS
+from .olbfgs import OnlineLBFGS
+# from .experimental import ModularLBFGS
+
+from .quasi_newton import BFGS, SR1, DFP, BroydenGood, BroydenBad, Greenstadt1, Greenstadt2, ColumnUpdatingMethod, ThomasOptimalMethod, PSB, Pearson2, SSVM
+from .lsr1 import LSR1
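Assuming the standard package layout implied by this diff, the names re-exported above become importable directly from the subpackage, e.g.:

    # hypothetical usage of the re-exports added above
    from torchzero.modules.quasi_newton import LBFGS, PolakRibiere, BFGS, LSR1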
torchzero/modules/quasi_newton/cg.py
@@ -0,0 +1,218 @@
+from abc import ABC, abstractmethod
+
+import torch
+
+from ...core import Chainable, Transform, apply
+from ...utils import TensorList, as_tensorlist
+
+
+class ConguateGradientBase(Transform, ABC):
+    """all CGs are the same except beta calculation"""
+    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None = None, inner: Chainable | None = None):
+        if defaults is None: defaults = {}
+        defaults['reset_interval'] = reset_interval
+        defaults['clip_beta'] = clip_beta
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    def initialize(self, p: TensorList, g: TensorList):
+        """runs on first step when prev_grads and prev_dir are not available"""
+
+    @abstractmethod
+    def get_beta(self, p: TensorList, g: TensorList, prev_g: TensorList, prev_d: TensorList) -> float | torch.Tensor:
+        """returns beta"""
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        tensors = as_tensorlist(tensors)
+        params = as_tensorlist(params)
+
+        step = self.global_state.get('step', 0)
+        prev_dir, prev_grads = self.get_state('prev_dir', 'prev_grad', params=params, cls=TensorList)
+
+        # initialize on first step
+        if step == 0:
+            self.initialize(params, tensors)
+            prev_dir.copy_(tensors)
+            prev_grads.copy_(tensors)
+            self.global_state['step'] = step + 1
+            return tensors
+
+        # get beta
+        beta = self.get_beta(params, tensors, prev_grads, prev_dir)
+        if self.settings[params[0]]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
+        prev_grads.copy_(tensors)
+
+        # inner step
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply(self.children['inner'], tensors, params, grads, vars))
+
+        # calculate new direction with beta
+        dir = tensors.add_(prev_dir.mul_(beta))
+        prev_dir.copy_(dir)
+
+        # resetting
+        self.global_state['step'] = step + 1
+        reset_interval = self.settings[params[0]]['reset_interval']
+        if reset_interval is not None and (step+1) % reset_interval == 0:
+            self.reset()
+
+        return dir
+
+# ------------------------------- Polak-Ribière ------------------------------ #
+def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
+    denom = prev_g.dot(prev_g)
+    if denom == 0: return 0
+    return g.dot(g - prev_g) / denom
+
+class PolakRibiere(ConguateGradientBase):
+    """Polak-Ribière-Polyak nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this."""
+    def __init__(self, clip_beta=True, reset_interval: int | None = None, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return polak_ribiere_beta(g, prev_g)
+
+# ------------------------------ Fletcher–Reeves ----------------------------- #
+def fletcher_reeves_beta(gg, prev_gg):
+    if prev_gg == 0: return 0
+    return gg / prev_gg
+
+class FletcherReeves(ConguateGradientBase):
+    """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+    def initialize(self, p, g):
+        self.global_state['prev_gg'] = g.dot(g)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        gg = g.dot(g)
+        beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
+        self.global_state['prev_gg'] = gg
+        return beta
+
+# ----------------------------- Hestenes–Stiefel ----------------------------- #
+def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+    grad_diff = g - prev_g
+    denom = prev_d.dot(grad_diff)
+    if denom == 0: return 0
+    return (g.dot(grad_diff) / denom).neg()
+
+
+class HestenesStiefel(ConguateGradientBase):
+    """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return hestenes_stiefel_beta(g, prev_d, prev_g)
+
+
+# --------------------------------- Dai–Yuan --------------------------------- #
+def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+    denom = prev_d.dot(g - prev_g)
+    if denom == 0: return 0
+    return (g.dot(g) / denom).neg()
+
+class DaiYuan(ConguateGradientBase):
+    """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return dai_yuan_beta(g, prev_d, prev_g)
+
+
+# -------------------------------- Liu-Storey -------------------------------- #
+def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
+    denom = prev_g.dot(prev_d)
+    if denom == 0: return 0
+    return g.dot(g - prev_g) / denom
+
+class LiuStorey(ConguateGradientBase):
+    """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return liu_storey_beta(g, prev_d, prev_g)
+
+# ----------------------------- Conjugate Descent ---------------------------- #
+class ConjugateDescent(Transform):
+    """Conjugate Descent (CD). This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    def __init__(self, inner: Chainable | None = None):
+        super().__init__(defaults={}, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        g = as_tensorlist(tensors)
+
+        prev_d = self.get_state('prev_dir', params=params, cls=TensorList, init = torch.zeros_like)
+        if 'denom' not in self.global_state:
+            self.global_state['denom'] = torch.tensor(0.).to(g[0])
+
+        prev_gd = self.global_state.get('prev_gd', 0)
+        if prev_gd == 0: beta = 0
+        else: beta = g.dot(g) / prev_gd
+
+        # inner step
+        if 'inner' in self.children:
+            g = as_tensorlist(apply(self.children['inner'], g, params, grads, vars))
+
+        dir = g.add_(prev_d.mul_(beta))
+        prev_d.copy_(dir)
+        self.global_state['prev_gd'] = g.dot(dir)
+        return dir
+
+
+# -------------------------------- Hager-Zhang ------------------------------- #
+def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
+    g_diff = g - prev_g
+    denom = prev_d.dot(g_diff)
+    if denom == 0: return 0
+
+    term1 = 1/denom
+    # term2
+    term2 = (g_diff - (2 * prev_d * (g_diff.pow(2).global_sum()/denom))).dot(g)
+    return (term1 * term2).neg()
+
+
+class HagerZhang(ConguateGradientBase):
+    """Hager-Zhang nonlinear conjugate gradient method,
+    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return hager_zhang_beta(g, prev_d, prev_g)
+
+
+# ----------------------------------- HS-DY ---------------------------------- #
+def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+    grad_diff = g - prev_g
+    denom = prev_d.dot(grad_diff)
+    if denom == 0: return 0
+
+    # Dai-Yuan
+    dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
+
+    # Hestenes–Stiefel
+    hs_beta = (g.dot(grad_diff) / denom).neg().clamp(min=0)
+
+    return max(0, min(dy_beta, hs_beta)) # type:ignore
+
+class HybridHS_DY(ConguateGradientBase):
+    """HS-DY hybrid conjugate gradient method.
+    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return hs_dy_beta(g, prev_d, prev_g)
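For orientation, the textbook forms of the two simplest betas computed above, matching polak_ribiere_beta and fletcher_reeves_beta, plus the recurrence that ConguateGradientBase.transform applies to the update (dir = tensors + beta * prev_dir); it is written in the code's convention, where the sign and the step size are handled by later modules such as the line search the docstrings recommend:

    \beta_k^{\mathrm{PR}} = \frac{g_k^\top (g_k - g_{k-1})}{g_{k-1}^\top g_{k-1}},
    \qquad
    \beta_k^{\mathrm{FR}} = \frac{g_k^\top g_k}{g_{k-1}^\top g_{k-1}},
    \qquad
    d_k = g_k + \beta_k d_{k-1}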
torchzero/modules/quasi_newton/experimental/__init__.py
@@ -0,0 +1 @@
+from .modular_lbfgs import ModularLBFGS