torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
|
@@ -24,11 +24,9 @@ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
|
|
|
24
24
|
Projects the gradient to the eigenbases of the preconditioner.
|
|
25
25
|
"""
|
|
26
26
|
for mat in Q:
|
|
27
|
-
if mat is None:
|
|
28
|
-
if len(mat) > 0:
|
|
27
|
+
if mat is not None and len(mat) > 0:
|
|
29
28
|
tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
|
|
30
29
|
else:
|
|
31
|
-
# I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
|
|
32
30
|
permute_order = list(range(1, len(tensors.shape))) + [0]
|
|
33
31
|
tensors = tensors.permute(permute_order)
|
|
34
32
|
|
|
@@ -40,8 +38,7 @@ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
|
|
|
40
38
|
Projects the gradient back to the original space.
|
|
41
39
|
"""
|
|
42
40
|
for mat in Q:
|
|
43
|
-
if mat is None:
|
|
44
|
-
if len(mat) > 0:
|
|
41
|
+
if mat is not None and len(mat) > 0:
|
|
45
42
|
tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
|
|
46
43
|
else:
|
|
47
44
|
permute_order = list(range(1, len(tensors.shape))) + [0]
|
|
@@ -59,8 +56,7 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
|
|
|
59
56
|
float_data = False
|
|
60
57
|
original_type = original_device = None
|
|
61
58
|
for m in mat:
|
|
62
|
-
if m is None:
|
|
63
|
-
if len(m) == 0:
|
|
59
|
+
if m is None or len(m) == 0:
|
|
64
60
|
matrix.append([])
|
|
65
61
|
continue
|
|
66
62
|
if m.dtype != torch.float:
|
|
@@ -100,13 +96,11 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
|
|
|
100
96
|
float_data = False
|
|
101
97
|
original_type = original_device = None
|
|
102
98
|
for m,o in zip(GG, Q_list):
|
|
103
|
-
if m is None:
|
|
104
|
-
assert o is not None
|
|
105
|
-
|
|
106
|
-
if len(m) == 0:
|
|
99
|
+
if m is None or len(m) == 0:
|
|
107
100
|
matrix.append([])
|
|
108
101
|
orth_matrix.append([])
|
|
109
102
|
continue
|
|
103
|
+
assert o is not None
|
|
110
104
|
if m.data.dtype != torch.float:
|
|
111
105
|
original_type = m.data.dtype
|
|
112
106
|
original_device = m.data.device
|
|
@@ -156,6 +150,24 @@ class SOAP(Transform):
|
|
|
156
150
|
learning rate. Defaults to 1.
|
|
157
151
|
bias_correction (bool, optional):
|
|
158
152
|
enables adam bias correction. Defaults to True.
|
|
153
|
+
|
|
154
|
+
Examples:
|
|
155
|
+
SOAP:
|
|
156
|
+
|
|
157
|
+
.. code-block:: python
|
|
158
|
+
|
|
159
|
+
opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))
|
|
160
|
+
|
|
161
|
+
Stabilized SOAP:
|
|
162
|
+
|
|
163
|
+
.. code-block:: python
|
|
164
|
+
|
|
165
|
+
opt = tz.Modular(
|
|
166
|
+
model.parameters(),
|
|
167
|
+
tz.m.SOAP(),
|
|
168
|
+
tz.m.NormalizeByEMA(max_ema_growth=1.2),
|
|
169
|
+
tz.m.LR(1e-2)
|
|
170
|
+
)
|
|
159
171
|
"""
|
|
160
172
|
def __init__(
|
|
161
173
|
self,
|
|
@@ -187,7 +199,7 @@ class SOAP(Transform):
|
|
|
187
199
|
super().__init__(defaults, uses_grad=False)
|
|
188
200
|
|
|
189
201
|
@torch.no_grad
|
|
190
|
-
def
|
|
202
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
191
203
|
updates = []
|
|
192
204
|
# update preconditioners
|
|
193
205
|
for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
|
|
@@ -200,7 +212,7 @@ class SOAP(Transform):
|
|
|
200
212
|
# initialize state on 1st step
|
|
201
213
|
if 'GG' not in state:
|
|
202
214
|
state["exp_avg"] = torch.zeros_like(t)
|
|
203
|
-
state["
|
|
215
|
+
state["exp_avg_sq_projected"] = torch.zeros_like(t)
|
|
204
216
|
|
|
205
217
|
if not precondition_1d and t.ndim <= 1:
|
|
206
218
|
state['GG'] = []
|
|
@@ -230,22 +242,20 @@ class SOAP(Transform):
|
|
|
230
242
|
# exponential moving averages
|
|
231
243
|
# this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
|
|
232
244
|
exp_avg: torch.Tensor = state["exp_avg"]
|
|
233
|
-
|
|
245
|
+
exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]
|
|
234
246
|
|
|
235
247
|
exp_avg.lerp_(t, 1-beta1)
|
|
236
248
|
|
|
237
249
|
if t_projected is None:
|
|
238
|
-
|
|
250
|
+
exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
|
|
239
251
|
else:
|
|
240
|
-
|
|
252
|
+
exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
|
|
241
253
|
|
|
242
254
|
# project exponential moving averages if they are accumulated unprojected
|
|
243
255
|
exp_avg_projected = exp_avg
|
|
244
256
|
if t_projected is not None:
|
|
245
257
|
exp_avg_projected = project(exp_avg, state['Q'])
|
|
246
258
|
|
|
247
|
-
exp_avg_sq_projected = exp_avg_sq
|
|
248
|
-
|
|
249
259
|
denom = exp_avg_sq_projected.sqrt().add_(eps)
|
|
250
260
|
# print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
|
|
251
261
|
|
|
@@ -273,6 +283,6 @@ class SOAP(Transform):
|
|
|
273
283
|
if state['GG'] is not None:
|
|
274
284
|
update_soap_covariances_(t, state['GG'], shampoo_beta)
|
|
275
285
|
if state['step'] % setting['precond_freq'] == 0:
|
|
276
|
-
state['Q'], state['
|
|
286
|
+
state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
|
|
277
287
|
|
|
278
288
|
return updates
|
|
@@ -35,6 +35,74 @@ def sophia_H(
|
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class SophiaH(Module):
|
|
38
|
+
"""SophiaH optimizer from https://arxiv.org/abs/2305.14342
|
|
39
|
+
|
|
40
|
+
This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is agressively clipped.
|
|
41
|
+
|
|
42
|
+
.. note::
|
|
43
|
+
In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.
|
|
44
|
+
|
|
45
|
+
.. note::
|
|
46
|
+
If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
|
|
47
|
+
|
|
48
|
+
.. note::
|
|
49
|
+
This module requires the a closure passed to the optimizer step,
|
|
50
|
+
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
51
|
+
The closure must accept a ``backward`` argument (refer to documentation).
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
beta1 (float, optional): first momentum. Defaults to 0.96.
|
|
55
|
+
beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
|
|
56
|
+
update_freq (int, optional):
|
|
57
|
+
frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
|
|
58
|
+
precond_scale (float, optional):
|
|
59
|
+
scale of the preconditioner. Defaults to 1.
|
|
60
|
+
clip (float, optional):
|
|
61
|
+
clips update to (-clip, clip). Defaults to 1.
|
|
62
|
+
eps (float, optional):
|
|
63
|
+
clips hessian diagonal esimate to be no less than this value. Defaults to 1e-12.
|
|
64
|
+
hvp_method (str, optional):
|
|
65
|
+
Determines how Hessian-vector products are evaluated.
|
|
66
|
+
|
|
67
|
+
- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
|
|
68
|
+
This requires creating a graph for the gradient.
|
|
69
|
+
- ``"forward"``: Use a forward finite difference formula to
|
|
70
|
+
approximate the HVP. This requires one extra gradient evaluation.
|
|
71
|
+
- ``"central"``: Use a central finite difference formula for a
|
|
72
|
+
more accurate HVP approximation. This requires two extra
|
|
73
|
+
gradient evaluations.
|
|
74
|
+
Defaults to "autograd".
|
|
75
|
+
h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
|
|
76
|
+
n_samples (int, optional):
|
|
77
|
+
number of hessian-vector products with random vectors to evaluate each time when updating
|
|
78
|
+
the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
|
|
79
|
+
seed (int | None, optional): seed for random vectors. Defaults to None.
|
|
80
|
+
inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
|
|
81
|
+
|
|
82
|
+
Examples:
|
|
83
|
+
Using SophiaH:
|
|
84
|
+
|
|
85
|
+
.. code-block:: python
|
|
86
|
+
|
|
87
|
+
opt = tz.Modular(
|
|
88
|
+
model.parameters(),
|
|
89
|
+
tz.m.SophiaH(),
|
|
90
|
+
tz.m.LR(0.1)
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
SophiaH preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
|
|
94
|
+
Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
|
|
95
|
+
SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):
|
|
96
|
+
|
|
97
|
+
.. code-block:: python
|
|
98
|
+
|
|
99
|
+
opt = tz.Modular(
|
|
100
|
+
model.parameters(),
|
|
101
|
+
tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
|
|
102
|
+
tz.m.LR(0.1)
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
"""
|
|
38
106
|
def __init__(
|
|
39
107
|
self,
|
|
40
108
|
beta1: float = 0.96,
|
|
@@ -85,23 +153,12 @@ class SophiaH(Module):
|
|
|
85
153
|
h = None
|
|
86
154
|
if step % update_freq == 0:
|
|
87
155
|
|
|
88
|
-
|
|
156
|
+
rgrad=None
|
|
89
157
|
for i in range(n_samples):
|
|
90
158
|
u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
|
|
91
159
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
assert grad is not None
|
|
95
|
-
Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
|
|
96
|
-
|
|
97
|
-
elif hvp_method == 'forward':
|
|
98
|
-
loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=var.get_grad(), normalize=True)
|
|
99
|
-
|
|
100
|
-
elif hvp_method == 'central':
|
|
101
|
-
loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
|
|
102
|
-
|
|
103
|
-
else:
|
|
104
|
-
raise ValueError(hvp_method)
|
|
160
|
+
Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
|
|
161
|
+
h=fd_h, normalize=True, retain_grad=i < n_samples-1)
|
|
105
162
|
|
|
106
163
|
if h is None: h = Hvp
|
|
107
164
|
else: torch._foreach_add_(h, Hvp)
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from .projection import
|
|
2
|
-
from .
|
|
3
|
-
from .structural import VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection
|
|
4
|
-
|
|
1
|
+
from .projection import ProjectionBase, VectorProjection, ScalarProjection
|
|
2
|
+
from .cast import To, ViewAsReal
|
|
5
3
|
# from .galore import GaLore
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .projection import ProjectionBase
|
|
3
|
+
from ...core import Chainable
|
|
4
|
+
|
|
5
|
+
class To(ProjectionBase):
|
|
6
|
+
"""Cast modules to specified device and dtype"""
|
|
7
|
+
def __init__(self, modules: Chainable, dtype: torch.dtype | None, device:torch.types.Device | None = None):
|
|
8
|
+
defaults = dict(dtype=dtype, device=device)
|
|
9
|
+
super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=defaults)
|
|
10
|
+
|
|
11
|
+
@torch.no_grad
|
|
12
|
+
def project(self, tensors, params, grads, loss, states, settings, current):
|
|
13
|
+
casted = []
|
|
14
|
+
for tensor, state, setting in zip(tensors,states, settings):
|
|
15
|
+
state['dtype'] = tensor.dtype
|
|
16
|
+
state['device'] = tensor.device
|
|
17
|
+
tensor = tensor.to(dtype=setting['dtype'], device=setting['device'])
|
|
18
|
+
casted.append(tensor)
|
|
19
|
+
return casted
|
|
20
|
+
|
|
21
|
+
@torch.no_grad
|
|
22
|
+
def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
|
|
23
|
+
uncasted = []
|
|
24
|
+
for tensor, state in zip(projected_tensors, states):
|
|
25
|
+
tensor = tensor.to(dtype=state['dtype'], device=state['device'])
|
|
26
|
+
uncasted.append(tensor)
|
|
27
|
+
return uncasted
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ViewAsReal(ProjectionBase):
|
|
31
|
+
"""View complex tensors as real tensors. Doesn't affect tensors that are already."""
|
|
32
|
+
def __init__(self, modules: Chainable):
|
|
33
|
+
super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=None)
|
|
34
|
+
|
|
35
|
+
@torch.no_grad
|
|
36
|
+
def project(self, tensors, params, grads, loss, states, settings, current):
|
|
37
|
+
views = []
|
|
38
|
+
for tensor, state in zip(tensors,states):
|
|
39
|
+
is_complex = torch.is_complex(tensor)
|
|
40
|
+
state['is_complex'] = is_complex
|
|
41
|
+
if is_complex: tensor = torch.view_as_real(tensor)
|
|
42
|
+
views.append(tensor)
|
|
43
|
+
return views
|
|
44
|
+
|
|
45
|
+
@torch.no_grad
|
|
46
|
+
def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
|
|
47
|
+
un_views = []
|
|
48
|
+
for tensor, state in zip(projected_tensors, states):
|
|
49
|
+
if state['is_complex']: tensor = torch.view_as_complex(tensor)
|
|
50
|
+
un_views.append(tensor)
|
|
51
|
+
return un_views
|