torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/projections/projection.py

@@ -1,29 +1,35 @@
 import math
-
+import warnings
 from abc import ABC, abstractmethod
-from collections
+from collections import defaultdict, ChainMap
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
 from typing import Any, Literal
-
+
 import torch
 
-from ...core import Chainable, Module,
-from ...utils import vec_to_tensors
+from ...core import Chainable, Module, Var
+from ...utils import vec_to_tensors, set_storage_
 
 
-def _make_projected_closure(closure,
+def _make_projected_closure(closure, project_fn, unproject_fn,
                             params: list[torch.Tensor], projected_params: list[torch.Tensor]):
-
     def projected_closure(backward=True):
-
+        # unproject projected params
+        unprojected_params = unproject_fn(projected_tensors=projected_params, current='params')
 
+        # set actual model parameters to suggested parameters
         with torch.no_grad():
             for p, new_p in zip(params, unprojected_params):
                 p.set_(new_p) # pyright: ignore[reportArgumentType]
 
+        # evaluate closure with suggested parameters
         if backward:
             loss = closure()
             grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
+
+            # project gradients on backward and set to projected parameter .grad attributes
+            projected_grads = project_fn(grads, current='grads')
             for p, g in zip(projected_params, projected_grads):
                 p.grad = g
 
@@ -34,27 +40,44 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
 
     return projected_closure
 
-
-
-
-
-
-
-)
-
-
-
-
-
-
-
-self
-
-
-
-
-
-
+class _FakeProjectedClosure:
+    """This is used when project_params is False. Then the closure is meant to only be used to evaluate the initial gradient.
+    It should just evaluate original closure, project the gradients, and set them to fake params.
+
+    I made it into a class so that it can know and raise when it evaluates closure more than once.
+    """
+    __slots__ = ('closure', 'project_fn', 'params', 'fake_params', 'evaluated')
+    def __init__(self, closure, project_fn, params: list[torch.Tensor], fake_params: list[torch.Tensor]):
+        self.closure = closure
+        self.project_fn = project_fn
+        self.params = params
+        self.fake_params = fake_params
+        self.evaluated = False
+
+    def __call__(self, backward: bool = True):
+        if self.evaluated:
+            raise RuntimeError("set project_params to True if projected modules require closure.")
+        self.evaluated = True
+
+        # evaluate closure with suggested parameters
+        if backward:
+
+            loss = self.closure()
+            grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+            # project gradients on backward and set to projected parameter .grad attributes
+            projected_grads = self.project_fn(grads, current='grads')
+            for p, g in zip(self.fake_params, projected_grads):
+                p.grad = g
+
+        else:
+            loss = self.closure(False)
+
+        return loss
+
+
+
+class ProjectionBase(Module, ABC):
     """
     Base class for projections.
     This is an abstract class, to use it, subclass it and override `project` and `unproject`.
@@ -84,57 +107,125 @@ class Projection(Module, ABC):
         self._project_grad = project_grad
         self._projected_params = None
 
+        self._states: dict[str, list[dict[str, Any]]] = {}
+        """per-parameter states for each projection target"""
+
     @abstractmethod
-    def project(
+    def project(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: list[ChainMap[str, Any]],
+        current: str,
+    ) -> Iterable[torch.Tensor]:
         """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
 
     @abstractmethod
-    def unproject(
-
+    def unproject(
+        self,
+        projected_tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: list[ChainMap[str, Any]],
+        current: str,
+    ) -> Iterable[torch.Tensor]:
+        """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.
+
+        Args:
+            projected_tensors (list[torch.Tensor]): projected tensors to unproject.
+            params (list[torch.Tensor]): original, unprojected parameters.
+            grads (list[torch.Tensor] | None): original, unprojected gradients
+            loss (torch.Tensor | None): loss at initial point.
+            states (list[dict[str, Any]]): list of state dictionaries per each UNPROJECTED tensor.
+            settings (list[ChainMap[str, Any]]): list of setting dictionaries per each UNPROJECTED tensor.
+            current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".
+
+        Returns:
+            Iterable[torch.Tensor]: unprojected tensors of the same shape as params
+        """
 
     @torch.no_grad
-    def step(self,
-
+    def step(self, var: Var):
+        params = var.params
+        settings = [self.settings[p] for p in params]
+
+        def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+            states = self._states.setdefault(current, [{} for _ in params])
+            return list(self.project(
+                tensors=tensors,
+                params=params,
+                grads=var.grad,
+                loss=var.loss,
+                states=states,
+                settings=settings,
+                current=current,
+            ))
+
+        projected_var = var.clone(clone_update=False)
+
+        closure = var.closure
+
+        # if this is True, update and grad were projected simultaneously under current="grads"
+        # so update will have to be unprojected with current="grads"
         update_is_grad = False
 
-        # closure
-
-
+        # if closure is provided and project_params=True, make new closure that evaluates projected params
+        # that also means projected modules can evaluate grad/update at will, it shouldn't be computed here
+        # but if it has already been computed, it should be projected
+        if self._project_params and closure is not None:
+
+            if self._project_update and var.update is not None:
+                # project update only if it already exists
+                projected_var.update = _project(var.update, current='update')
+
             else:
+                # update will be set to gradients on var.get_grad()
+                # therefore projection will happen with current="grads"
                 update_is_grad = True
-            if self._project_grad and vars.grad is not None: projected_vars.grad = list(self.project(vars.grad, vars=vars, current='grads'))
 
-
+            # project grad only if it already exists
+            if self._project_grad and var.grad is not None:
+                projected_var.grad = _project(var.grad, current='grads')
+
+        # otherwise update/grad needs to be calculated and projected here
         else:
             if self._project_update:
-                if
+                if var.update is None:
                     # update is None, meaning it will be set to `grad`.
                     # we can project grad and use it for update
-                    grad =
-
-
-
-                    del vars.update
+                    grad = var.get_grad()
+                    projected_var.grad = _project(grad, current='grads')
+                    projected_var.update = [g.clone() for g in projected_var.grad]
+                    del var.update
                     update_is_grad = True
 
                 else:
-                    update
-
-
+                    # update exists so it needs to be projected
+                    update = var.get_update()
+                    projected_var.update = _project(update, current='update')
+                    del update, var.update
+
+            if self._project_grad and projected_var.grad is None:
+                # projected_vars.grad may have been projected simultaneously with update
+                # but if that didn't happen, it is projected here
+                grad = var.get_grad()
+                projected_var.grad = _project(grad, current='grads')
 
-            if self._project_grad and projected_vars.grad is None:
-                grad = vars.get_grad()
-                projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
 
         original_params = None
         if self._project_params:
-            original_params = [p.clone() for p in
-            projected_params =
+            original_params = [p.clone() for p in var.params]
+            projected_params = _project(var.params, current='params')
 
         else:
             # make fake params for correct shapes and state storage
            # they reuse update or grad storage for memory efficiency
-            projected_params =
+            projected_params = projected_var.update if projected_var.update is not None else projected_var.grad
             assert projected_params is not None
 
         if self._projected_params is None:
@@ -146,99 +237,102 @@ class Projection(Module, ABC):
             for empty_p, new_p in zip(self._projected_params, projected_params):
                 empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]
 
+        projected_params = self._projected_params
+        # projected_settings = [self.settings[p] for p in projected_params]
+
+        def _unproject(projected_tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+            states = self._states.setdefault(current, [{} for _ in params])
+            return list(self.unproject(
+                projected_tensors=projected_tensors,
+                params=params,
+                grads=var.grad,
+                loss=var.loss,
+                states=states,
+                settings=settings,
+                current=current,
+            ))
+
         # project closure
         if self._project_params:
-            closure =
-
-
+            projected_var.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
+                params=params, projected_params=projected_params)
+
+        elif closure is not None:
+            projected_var.closure = _FakeProjectedClosure(closure, project_fn=_project,
+                params=params, fake_params=projected_params)
 
         else:
-
-
-        # step
-
-
-            _projected_get_grad_override,
-            projection=self,
-            unprojected_vars=vars,
-            self=projected_vars,
-        )
-        projected_vars = self.children['modules'].step(projected_vars)
+            projected_var.closure = None
+
+        # ----------------------------------- step ----------------------------------- #
+        projected_var.params = projected_params
+        projected_var = self.children['modules'].step(projected_var)
 
         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
         if not self._project_params:
             for p in self._projected_params:
-                p
+                set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))
 
-        # unproject
-
-
-
-
+        # --------------------------------- unproject -------------------------------- #
+        unprojected_var = projected_var.clone(clone_update=False)
+        unprojected_var.closure = var.closure
+        unprojected_var.params = var.params
+        unprojected_var.grad = var.grad
 
         if self._project_update:
-            assert
-
-            del
+            assert projected_var.update is not None
+            unprojected_var.update = _unproject(projected_var.update, current='grads' if update_is_grad else 'update')
+            del projected_var.update
 
-
-        # if self._project_grad:
-        #     assert projected_vars.grad is not None
-        #     unprojected_vars.grad = list(self.unproject(projected_vars.grad, vars=vars))
-
-        del projected_vars
+        del projected_var
 
+        # original params are stored if params are projected
         if original_params is not None:
-            for p, o in zip(
+            for p, o in zip(unprojected_var.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return
-
-
-
-class FlipConcatProjection(Projection):
-    """
-    for testing
-    """
-
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-    @torch.no_grad
-    def project(self, tensors, vars, current):
-        return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]
-
-    @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        return vec_to_tensors(vec=tensors[0].flip(0), reference=vars.params)
+        return unprojected_var
 
 
-class NoopProjection(Projection):
-    """an example projection which doesn't do anything for testing"""
 
-
+# basic examples
+class VectorProjection(ProjectionBase):
+    """projection that concatenates all parameters into a vector"""
+    def __init__(
+        self,
+        modules: Chainable,
+        project_update=True,
+        project_params=True,
+        project_grad=True,
+    ):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
-        return tensors
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        return [torch.cat([t.ravel() for t in tensors])]
 
     @torch.no_grad
-    def unproject(self,
-        return
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        return vec_to_tensors(vec=projected_tensors[0], reference=params)
 
-class MultipyProjection(Projection):
-    """an example projection which multiplies everything by 2"""
 
-
+class ScalarProjection(ProjectionBase):
+    """projetion that splits all parameters into individual scalars"""
+    def __init__(
+        self,
+        modules: Chainable,
+        project_update=True,
+        project_params=True,
+        project_grad=True,
+    ):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
-        return
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        return [s for t in tensors for s in t.ravel().unbind(0)]
 
     @torch.no_grad
-    def unproject(self,
-        return torch.
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        return vec_to_tensors(vec=torch.stack(projected_tensors), reference=params)
 
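The projection API above replaces the old vars-based project/unproject signature with explicit tensors/params/grads/loss/states/settings/current arguments. As a quick orientation, here is a minimal sketch of a custom projection written against the new ProjectionBase interface, modeled on the VectorProjection example in the diff. The import paths are assumptions taken from the file list, and FlipVectorProjection itself is hypothetical, not part of the package.

import torch

# Paths assumed from the file list above; ProjectionBase and vec_to_tensors appear in the diff.
from torchzero.modules.projections.projection import ProjectionBase
from torchzero.utils import vec_to_tensors


class FlipVectorProjection(ProjectionBase):
    """Hypothetical example: concatenate all tensors into one vector and reverse it."""

    def __init__(self, modules, project_update=True, project_params=True, project_grad=True):
        super().__init__(modules, project_update=project_update,
                         project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        # flatten every tensor, concatenate, then flip the resulting vector
        return [torch.cat([t.ravel() for t in tensors]).flip(0)]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        # undo the flip and split the vector back into tensors shaped like `params`
        return vec_to_tensors(vec=projected_tensors[0].flip(0), reference=params)

Modules chained under such a projection then operate on the single flattened vector, and the per-tensor states/settings dictionaries give the projection storage if it needs to remember factors between calls.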
torchzero/modules/quasi_newton/__init__.py

@@ -1,7 +1,46 @@
-from .cg import
+from .cg import (
+    ConjugateDescent,
+    DaiYuan,
+    FletcherReeves,
+    HagerZhang,
+    HestenesStiefel,
+    HybridHS_DY,
+    LiuStorey,
+    PolakRibiere,
+    ProjectedGradientMethod,
+)
+from .diagonal_quasi_newton import (
+    DNRTR,
+    DiagonalBFGS,
+    DiagonalQuasiCauchi,
+    DiagonalSR1,
+    DiagonalWeightedQuasiCauchi,
+    NewDQN,
+)
 from .lbfgs import LBFGS
-from .
-# from .
+from .lsr1 import LSR1
+# from .olbfgs import OnlineLBFGS
 
-from .
-from .
+# from .experimental import ModularLBFGS
+from .quasi_newton import (
+    BFGS,
+    DFP,
+    ICUM,
+    PSB,
+    SR1,
+    SSVM,
+    BroydenBad,
+    BroydenGood,
+    FletcherVMM,
+    GradientCorrection,
+    Greenstadt1,
+    Greenstadt2,
+    Horisho,
+    McCormick,
+    NewSSM,
+    Pearson,
+    ProjectedNewtonRaphson,
+    ThomasOptimalMethod,
+    ShorR,
+)
+from .trust_region import CubicRegularization, TrustCG, TrustRegionBase