torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/sophia_h.py

@@ -2,7 +2,7 @@ from typing import Literal
 from collections.abc import Callable
 import torch
 
-from ...core import Module, Target, Transform, Chainable,
+from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
 
@@ -56,8 +56,8 @@ class SophiaH(Module):
         self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self,
-        params =
+    def step(self, var):
+        params = var.params
         settings = self.settings[params[0]]
         hvp_method = settings['hvp_method']
         fd_h = settings['fd_h']
@@ -71,15 +71,15 @@ class SophiaH(Module):
             self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
         generator = self.global_state['generator']
 
-        beta1, beta2, precond_scale, clip, eps = self.get_settings(
-            'beta1', 'beta2', 'precond_scale', 'clip', 'eps',
+        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
+            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)
 
-        exp_avg, h_exp_avg = self.get_state('exp_avg', 'h_exp_avg',
+        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
 
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        closure =
+        closure = var.closure
         assert closure is not None
 
         h = None
@@ -90,12 +90,12 @@ class SophiaH(Module):
             u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
 
             if hvp_method == 'autograd':
-                if grad is None: grad =
+                if grad is None: grad = var.get_grad(create_graph=True)
                 assert grad is not None
                 Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
 
            elif hvp_method == 'forward':
-                loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=
+                loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=var.get_grad(), normalize=True)
 
            elif hvp_method == 'central':
                loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
@@ -109,11 +109,11 @@ class SophiaH(Module):
         assert h is not None
         if n_samples > 1: torch._foreach_div_(h, n_samples)
 
-        update =
+        update = var.get_update()
         if 'inner' in self.children:
-            update =
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
 
-
+        var.update = sophia_H(
            tensors=TensorList(update),
            h=TensorList(h) if h is not None else None,
            exp_avg_=exp_avg,
@@ -126,4 +126,4 @@ class SophiaH(Module):
            eps=eps,
            step=step,
        )
-        return
+        return var
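The pattern visible in these SophiaH hunks recurs throughout 0.3.10: `step` now receives a single `var` object that carries `params`, `closure`, the lazily computed gradient (`var.get_grad()`) and update (`var.get_update()`), and the module returns the mutated `var`. Below is a minimal hypothetical module in that style; the `Module.__init__(defaults=...)` call is an assumption, not something shown in this diff.

```python
import torch
from torchzero.core import Module  # a Var object is passed into step(), as in the hunks above


class SignDescent(Module):
    """Hypothetical module: replaces the update with its sign, scaled by `alpha`."""

    def __init__(self, alpha: float = 1.0):
        # passing per-parameter defaults to Module.__init__ is an assumption
        super().__init__(defaults=dict(alpha=alpha))

    @torch.no_grad
    def step(self, var):
        params = var.params
        alpha = self.settings[params[0]]['alpha']   # settings keyed by parameter, as above

        update = var.get_update()                   # falls back to the gradient if no update is set
        var.update = [u.sign().mul_(alpha) for u in update]
        return var                                  # 0.3.10 modules return the Var they received
```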

torchzero/modules/projections/dct.py

@@ -34,8 +34,8 @@ class DCTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors,
-        settings = self.settings[
+    def project(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         dims = settings['dims']
         norm = settings['norm']
 
@@ -54,8 +54,8 @@ class DCTProjection(Projection):
         return projected
 
     @torch.no_grad
-    def unproject(self, tensors,
-        settings = self.settings[
+    def unproject(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         dims = settings['dims']
         norm = settings['norm']
 
torchzero/modules/projections/fft.py

@@ -45,8 +45,8 @@ class FFTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors,
-        settings = self.settings[
+    def project(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         one_d = settings['one_d']
         norm = settings['norm']
 
@@ -60,14 +60,14 @@ class FFTProjection(Projection):
         return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable
 
     @torch.no_grad
-    def unproject(self, tensors,
-        settings = self.settings[
+    def unproject(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         one_d = settings['one_d']
         norm = settings['norm']
 
         if one_d:
             vec = torch.view_as_complex(tensors[0])
             unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
-            return vec_to_tensors(unprojected_vec, reference=
+            return vec_to_tensors(unprojected_vec, reference=var.params)
 
-        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors,
+        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors, var.params)] # pylint:disable=not-callable

torchzero/modules/projections/projection.py

@@ -6,15 +6,15 @@ from typing import Any, Literal
 import warnings
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Var
 from ...utils import vec_to_tensors
 
 
-def _make_projected_closure(closure,
+def _make_projected_closure(closure, var: Var, projection: "Projection",
                             params: list[torch.Tensor], projected_params: list[torch.Tensor]):
 
     def projected_closure(backward=True):
-        unprojected_params = projection.unproject(projected_params,
+        unprojected_params = projection.unproject(projected_params, var, current='params')
 
         with torch.no_grad():
             for p, new_p in zip(params, unprojected_params):
@@ -23,7 +23,7 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
         if backward:
             loss = closure()
             grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-            projected_grads = projection.project(grads,
+            projected_grads = projection.project(grads, var, current='grads')
             for p, g in zip(projected_params, projected_grads):
                 p.grad = g
 
@@ -38,15 +38,15 @@ def _projected_get_grad_override(
     retain_graph: bool | None = None,
     create_graph: bool = False,
     projection: Any = ...,
-
+    unprojected_var: Any = ...,
     self: Any = ...,
 ):
     assert isinstance(projection, Projection)
-    assert isinstance(
-    assert isinstance(self,
+    assert isinstance(unprojected_var, Var)
+    assert isinstance(self, Var)
 
     if self.grad is not None: return self.grad
-    grads =
+    grads = unprojected_var.get_grad(retain_graph, create_graph)
     projected_grads = list(projection.project(grads, self, current='grads'))
     self.grad = projected_grads
     for p, g in zip(self.params, projected_grads):
@@ -85,56 +85,56 @@ class Projection(Module, ABC):
         self._projected_params = None
 
     @abstractmethod
-    def project(self, tensors: list[torch.Tensor],
+    def project(self, tensors: list[torch.Tensor], var: Var, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
         """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
 
     @abstractmethod
-    def unproject(self, tensors: list[torch.Tensor],
+    def unproject(self, tensors: list[torch.Tensor], var: Var, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
         """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
 
     @torch.no_grad
-    def step(self,
-
+    def step(self, var: Var):
+        projected_var = var.clone(clone_update=False)
         update_is_grad = False
 
         # closure will calculate projected update and grad if needed
-        if self._project_params and
-            if self._project_update and
+        if self._project_params and var.closure is not None:
+            if self._project_update and var.update is not None: projected_var.update = list(self.project(var.update, var=var, current='update'))
            else:
                update_is_grad = True
-            if self._project_grad and
+            if self._project_grad and var.grad is not None: projected_var.grad = list(self.project(var.grad, var=var, current='grads'))
 
         # project update and grad, unprojected attributes are deleted
         else:
             if self._project_update:
-                if
+                if var.update is None:
                    # update is None, meaning it will be set to `grad`.
                    # we can project grad and use it for update
-                    grad =
-
-                    if self._project_grad:
-                    else:
-                    del
+                    grad = var.get_grad()
+                    projected_var.grad = list(self.project(grad, var=var, current='grads'))
+                    if self._project_grad: projected_var.update = [g.clone() for g in projected_var.grad]
+                    else: projected_var.update = projected_var.grad.copy() # don't clone because grad shouldn't be used
+                    del var.update
                    update_is_grad = True
 
                else:
-                    update =
-
-                    del update,
+                    update = var.get_update()
+                    projected_var.update = list(self.project(update, var=var, current='update'))
+                    del update, var.update
 
-            if self._project_grad and
-                grad =
-
+            if self._project_grad and projected_var.grad is None:
+                grad = var.get_grad()
+                projected_var.grad = list(self.project(grad, var=var, current='grads'))
 
         original_params = None
         if self._project_params:
-            original_params = [p.clone() for p in
-            projected_params = self.project(
+            original_params = [p.clone() for p in var.params]
+            projected_params = self.project(var.params, var=var, current='params')
 
         else:
            # make fake params for correct shapes and state storage
            # they reuse update or grad storage for memory efficiency
-            projected_params =
+            projected_params = projected_var.update if projected_var.update is not None else projected_var.grad
         assert projected_params is not None
 
         if self._projected_params is None:
@@ -148,22 +148,22 @@ class Projection(Module, ABC):
 
         # project closure
         if self._project_params:
-            closure =
-
+            closure = var.closure; params = var.params
+            projected_var.closure = _make_projected_closure(closure, var=var, projection=self, params=params,
                                                             projected_params=self._projected_params)
 
         else:
-
+            projected_var.closure = None
 
         # step
-
-
+        projected_var.params = self._projected_params
+        projected_var.get_grad = partial(
             _projected_get_grad_override,
             projection=self,
-
-            self=
+            unprojected_var=var,
+            self=projected_var,
         )
-
+        projected_var = self.children['modules'].step(projected_var)
 
         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
@@ -172,28 +172,28 @@ class Projection(Module, ABC):
             p.set_(torch.empty(0, device=p.device, dtype=p.dtype)) # pyright: ignore[reportArgumentType]
 
         # unproject
-
-
-
-
+        unprojected_var = projected_var.clone(clone_update=False)
+        unprojected_var.closure = var.closure
+        unprojected_var.params = var.params
+        unprojected_var.grad = var.grad
 
         if self._project_update:
-            assert
-
-            del
+            assert projected_var.update is not None
+            unprojected_var.update = list(self.unproject(projected_var.update, var=var, current='grads' if update_is_grad else 'update'))
+            del projected_var.update
 
         # unprojecting grad doesn't make sense?
         # if self._project_grad:
-        #     assert
-        #
+        #     assert projected_var.grad is not None
+        #     unprojected_var.grad = list(self.unproject(projected_var.grad, var=var))
 
-        del
+        del projected_var
 
         if original_params is not None:
-            for p, o in zip(
+            for p, o in zip(unprojected_var.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return
+        return unprojected_var
 
 
 
@@ -206,12 +206,12 @@ class FlipConcatProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
+    def project(self, tensors, var, current):
         return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]
 
     @torch.no_grad
-    def unproject(self, tensors,
-        return vec_to_tensors(vec=tensors[0].flip(0), reference=
+    def unproject(self, tensors, var, current):
+        return vec_to_tensors(vec=tensors[0].flip(0), reference=var.params)
 
 
 class NoopProjection(Projection):
@@ -221,11 +221,11 @@ class NoopProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
+    def project(self, tensors, var, current):
         return tensors
 
     @torch.no_grad
-    def unproject(self, tensors,
+    def unproject(self, tensors, var, current):
         return tensors
 
 class MultipyProjection(Projection):
@@ -235,10 +235,10 @@ class MultipyProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
+    def project(self, tensors, var, current):
         return torch._foreach_mul(tensors, 2)
 
     @torch.no_grad
-    def unproject(self, tensors,
+    def unproject(self, tensors, var, current):
         return torch._foreach_div(tensors, 2)
 
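Every `Projection` subclass above now implements `project(self, tensors, var, current)` and `unproject(self, tensors, var, current)`, where `current` is `'params'`, `'grads'` or `'update'` and `var` exposes the original parameters. A hypothetical subclass in the same spirit as `MultipyProjection`; the import path is inferred from the file listing at the top of this diff.

```python
import torch
from torchzero.modules.projections.projection import Projection  # path per the file listing


class NegateProjection(Projection):
    """Hypothetical projection: the wrapped modules operate on negated tensors."""

    def __init__(self, modules, project_update=True, project_params=False, project_grad=False):
        super().__init__(modules, project_update=project_update,
                         project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, var, current):
        # `current` tells whether `tensors` are the params, the grads, or the update
        return torch._foreach_neg(tensors)

    @torch.no_grad
    def unproject(self, tensors, var, current):
        return torch._foreach_neg(tensors)
```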

torchzero/modules/projections/structural.py

@@ -17,12 +17,12 @@ class VectorProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
+    def project(self, tensors, var, current):
         return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
 
     @torch.no_grad
-    def unproject(self, tensors,
-        return vec_to_tensors(vec=tensors[0], reference=
+    def unproject(self, tensors, var, current):
+        return vec_to_tensors(vec=tensors[0], reference=var.params)
 
 
 
@@ -33,8 +33,8 @@ class TensorizeProjection(Projection):
         super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
-        params =
+    def project(self, tensors, var, current):
+        params = var.params
         max_side = self.settings[params[0]]['max_side']
         num_elems = sum(t.numel() for t in tensors)
 
@@ -60,12 +60,12 @@ class TensorizeProjection(Projection):
         return [vec.view(dims)]
 
     @torch.no_grad
-    def unproject(self, tensors,
+    def unproject(self, tensors, var, current):
         remainder = self.global_state['remainder']
         # warnings.warn(f'{tensors[0].shape = }')
         vec = tensors[0].view(-1)
         if remainder > 0: vec = vec[:-remainder]
-        return vec_to_tensors(vec,
+        return vec_to_tensors(vec, var.params)
 
 class BlockPartition(Projection):
     """splits parameters into blocks (for now flatttens them and chunks)"""
@@ -74,9 +74,9 @@ class BlockPartition(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors,
+    def project(self, tensors, var, current):
         partitioned = []
-        for p,t in zip(
+        for p,t in zip(var.params, tensors):
             settings = self.settings[p]
             max_size = settings['max_size']
             n = t.numel()
@@ -101,10 +101,10 @@ class BlockPartition(Projection):
         return partitioned
 
     @torch.no_grad
-    def unproject(self, tensors,
+    def unproject(self, tensors, var, current):
         ti = iter(tensors)
         unprojected = []
-        for p in
+        for p in var.params:
             settings = self.settings[p]
             n = p.numel()
 
@@ -130,19 +130,19 @@ class TensorNormsProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
-        orig = self.get_state(f'{current}_orig'
+    def project(self, tensors, var, current):
+        orig = self.get_state(var.params, f'{current}_orig')
         torch._foreach_copy_(orig, tensors)
 
         norms = torch._foreach_norm(tensors)
-        self.get_state(f'{current}_orig_norms',
+        self.get_state(var.params, f'{current}_orig_norms', cls=TensorList).set_(norms)
 
         return [torch.stack(norms)]
 
     @torch.no_grad
-    def unproject(self, tensors,
-        orig = self.get_state(f'{current}_orig'
-        orig_norms = torch.stack(self.get_state(f'{current}_orig_norms'
+    def unproject(self, tensors, var, current):
+        orig = self.get_state(var.params, f'{current}_orig')
+        orig_norms = torch.stack(self.get_state(var.params, f'{current}_orig_norms'))
         target_norms = tensors[0]
 
         orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
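Another change that runs through all of these files is the params-first signature of the per-parameter helpers: `self.get_settings(params, ..., cls=NumberList)`, `self.get_state(params, ..., cls=TensorList)` and `self.get_state(var.params, f'{current}_orig')`. A rough sketch of how they combine inside `step(var)` follows; the `defaults=` constructor argument and the zero-initialization of new state buffers are assumptions based on how SophiaH and TensorNormsProjection use these calls above.

```python
import torch
from torchzero.core import Module
from torchzero.utils import NumberList, TensorList   # imported this way in the hunks above


class EmaSketch(Module):
    """Hypothetical module showing the params-first get_settings / get_state calls."""

    def __init__(self, beta1: float = 0.9, beta2: float = 0.999):
        # passing defaults to Module.__init__ is an assumption (see the SignDescent sketch earlier)
        super().__init__(defaults=dict(beta1=beta1, beta2=beta2))

    @torch.no_grad
    def step(self, var):
        params = var.params

        # per-parameter hyperparameters, one NumberList per setting name
        beta1, beta2 = self.get_settings(params, 'beta1', 'beta2', cls=NumberList)

        # per-parameter state buffers, one TensorList per state name (assumed zero-initialized)
        exp_avg, exp_avg_sq = self.get_state(params, 'exp_avg', 'exp_avg_sq', cls=TensorList)

        update = var.get_update()
        new_update = []
        for u, m, v, b1, b2 in zip(update, exp_avg, exp_avg_sq, beta1, beta2):
            m.lerp_(u, 1 - b1)                          # first moment
            v.mul_(b2).addcmul_(u, u, value=1 - b2)     # second moment
            new_update.append(m / (v.sqrt() + 1e-8))

        var.update = new_update
        return var
```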
torchzero/modules/quasi_newton/__init__.py

@@ -1,7 +1,36 @@
-from .cg import
+from .cg import (
+    ConjugateDescent,
+    DaiYuan,
+    FletcherReeves,
+    HagerZhang,
+    HestenesStiefel,
+    HybridHS_DY,
+    LiuStorey,
+    PolakRibiere,
+    ProjectedGradientMethod,
+)
 from .lbfgs import LBFGS
+from .lsr1 import LSR1
 from .olbfgs import OnlineLBFGS
-# from .experimental import ModularLBFGS
 
-from .
-from .
+# from .experimental import ModularLBFGS
+from .quasi_newton import (
+    BFGS,
+    DFP,
+    PSB,
+    SR1,
+    SSVM,
+    BroydenBad,
+    BroydenGood,
+    ColumnUpdatingMethod,
+    FletcherVMM,
+    GradientCorrection,
+    Greenstadt1,
+    Greenstadt2,
+    Horisho,
+    McCormick,
+    NewSSM,
+    Pearson,
+    ProjectedNewtonRaphson,
+    ThomasOptimalMethod,
+)
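After this change, the conjugate-gradient and quasi-Newton modules are re-exported from the subpackage itself, so they can be imported directly:

```python
# names re-exported by torchzero/modules/quasi_newton/__init__.py as of 0.3.10
from torchzero.modules.quasi_newton import (
    BFGS, DFP, SR1, PSB, SSVM,
    BroydenGood, BroydenBad,
    DaiYuan, FletcherReeves, HagerZhang, HestenesStiefel, PolakRibiere,
    LBFGS, LSR1, OnlineLBFGS,
)

# how these are combined with other modules depends on torchzero's modular API,
# which is outside the scope of this diff
```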