mlx-nufft 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_nufft/__init__.py +48 -0
- mlx_nufft/api.py +388 -0
- mlx_nufft/dfmath.py +134 -0
- mlx_nufft/gpu_t3.py +1364 -0
- mlx_nufft/nd.py +1119 -0
- mlx_nufft/ref_t3.py +228 -0
- mlx_nufft/sizing.py +111 -0
- mlx_nufft/types12.py +290 -0
- mlx_nufft/vkfft_backend.py +94 -0
- mlx_nufft-0.1.1.dist-info/METADATA +202 -0
- mlx_nufft-0.1.1.dist-info/RECORD +15 -0
- mlx_nufft-0.1.1.dist-info/WHEEL +5 -0
- mlx_nufft-0.1.1.dist-info/licenses/LICENSE +216 -0
- mlx_nufft-0.1.1.dist-info/licenses/NOTICE +28 -0
- mlx_nufft-0.1.1.dist-info/top_level.txt +1 -0
mlx_nufft/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""mlx-nufft: non-uniform FFTs on Apple GPUs (Metal, via MLX).
|
|
2
|
+
|
|
3
|
+
Drop-in mirror of the `finufft` Python package's interface:
|
|
4
|
+
|
|
5
|
+
import mlx_nufft as finufft
|
|
6
|
+
fk = finufft.nufft2d1(x, y, c, (N1, N2), eps=1e-6) # types 1/2/3, dims 1/2/3
|
|
7
|
+
plan = finufft.Plan(1, (N1, N2), eps=1e-6)
|
|
8
|
+
plan.setpts(x, y)
|
|
9
|
+
fk = plan.execute(c)
|
|
10
|
+
|
|
11
|
+
plus the native plan classes:
|
|
12
|
+
|
|
13
|
+
from mlx_nufft import Type3Plan
|
|
14
|
+
plan = Type3Plan((x1, x2, x3), (s1, s2, s3), eps=1e-5, isign=+1)
|
|
15
|
+
f = plan.execute(c) # f[k] = sum_j c[j] exp(i*isign*s_k.x_j)
|
|
16
|
+
|
|
17
|
+
Precision model (see REPORT.md): fp32 GPU pipeline with the precision-
|
|
18
|
+
critical setup (coordinate rescale, pre/post phases) in fp64 at plan time
|
|
19
|
+
('crit64', the default). Plans cache all geometry-dependent state, so
|
|
20
|
+
fixed-geometry workloads pay setup once and execute() per call.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from .gpu_t3 import GpuT3Plan as Type3Plan
|
|
24
|
+
from .types12 import Type1Plan, Type2Plan
|
|
25
|
+
from .nd import Type1PlanND, Type2PlanND
|
|
26
|
+
from .dfmath import expi, EXPI_MAX_PHASE
|
|
27
|
+
from .sizing import kernel_params, next235even
|
|
28
|
+
from .api import (Plan,
|
|
29
|
+
nufft1d1, nufft1d2, nufft1d3,
|
|
30
|
+
nufft2d1, nufft2d2, nufft2d3,
|
|
31
|
+
nufft3d1, nufft3d2, nufft3d3)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def vkfft_available():
    """Report whether the optional VkFFT-Metal FFT backend can be used.

    Returns True when the backend selected with ``fft_backend='vkfft'`` has
    been built and loads (see vkfft_bridge/build.sh). The backend module is
    imported lazily, only when this check is requested.
    """
    from . import vkfft_backend as backend_module

    is_ready = backend_module.available()
    return is_ready
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
__version__ = "0.1.1"

# Public names exported by ``from mlx_nufft import *``.
__all__ = [
    # native plan classes
    "Type3Plan",
    "Type1Plan",
    "Type2Plan",
    "Type1PlanND",
    "Type2PlanND",
    # finufft-compatible interface
    "Plan",
    "nufft1d1",
    "nufft1d2",
    "nufft1d3",
    "nufft2d1",
    "nufft2d2",
    "nufft2d3",
    "nufft3d1",
    "nufft3d2",
    "nufft3d3",
    # helpers
    "kernel_params",
    "next235even",
    "vkfft_available",
    "expi",
    "EXPI_MAX_PHASE",
]
|
mlx_nufft/api.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""finufft-compatible API for mlx_nufft.
|
|
2
|
+
|
|
3
|
+
Drop-in mirror of the `finufft` Python package's interface, so existing
|
|
4
|
+
callers switch with one import line:
|
|
5
|
+
|
|
6
|
+
import mlx_nufft as finufft
|
|
7
|
+
fk = finufft.nufft2d1(x, y, c, (N1, N2), eps=1e-6)
|
|
8
|
+
plan = finufft.Plan(1, (N1, N2), n_trans=8, eps=1e-6)
|
|
9
|
+
plan.setpts(x, y)
|
|
10
|
+
fk = plan.execute(c)
|
|
11
|
+
|
|
12
|
+
Semantics mirrored from finufft 2.x:
|
|
13
|
+
- mode boxes are modeord=0: k_d integer in [-(N_d//2), (N_d-1)//2],
|
|
14
|
+
even or odd N_d;
|
|
15
|
+
- isign: non-negative means +i in the exponential (type-1/3 default +1,
|
|
16
|
+
type-2 default -1);
|
|
17
|
+
- multi-vector inputs: leading n_trans axis on strengths/mode arrays;
|
|
18
|
+
- out= arrays are filled in place when supplied;
|
|
19
|
+
- x (etc.) in [-pi, pi), folded otherwise.
|
|
20
|
+
|
|
21
|
+
Differences (documented, not silent):
|
|
22
|
+
- computation is the validated fp32 GPU pipeline with fp64-critical
|
|
23
|
+
setup ('crit64'); requesting eps below 1e-6 clamps to 1e-6 with a
|
|
24
|
+
warning (the fp32 accuracy envelope: see ACCEPTANCE.md);
|
|
25
|
+
- complex128 inputs are accepted and returned as complex128, but the
|
|
26
|
+
transform itself is fp32-grade;
|
|
27
|
+
- modeord=1 (FFT ordering) is not implemented (raises);
|
|
28
|
+
- 1D/2D type 3 currently run as degenerate slices of the validated 3D
|
|
29
|
+
type-3 kernel (functional; not speed-tuned).
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
import warnings
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
|
|
36
|
+
from .nd import Type1PlanND, Type2PlanND
|
|
37
|
+
from .gpu_t3 import GpuT3Plan
|
|
38
|
+
|
|
39
|
+
# Smallest tolerance the fp32 GPU pipeline can honour; _norm_eps clamps
# smaller requests up to this value (with a warning).
_EPS_FLOOR = 1e-6

# finufft options with no meaning for the Metal pipeline: _check_opts drops
# them silently instead of reporting them as unknown.
_IGNORED_OPTS = {
    "chkbnds",
    "debug",
    "fftw",
    "maxbatchsize",
    "nthreads",
    "showwarn",
    "spread_debug",
    "spread_kerevalmeth",
    "spread_kerpad",
    "spread_max_sp_size",
    "spread_nthr_atomic",
    "spread_sort",
    "spread_thread",
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _check_opts(kwargs):
    """Split finufft-style keyword options into the settings this package uses.

    ``kwargs`` is copied, never modified. Recognised keys are consumed:
    ``modeord`` (only 0 is supported), ``upsampfac`` (0 is finufft's "auto"
    sentinel and becomes None), ``prec`` (default 'crit64') and
    ``fft_backend`` (default 'mlx'). Keys in _IGNORED_OPTS are dropped
    silently; any other leftover key triggers a single warning.

    Returns:
        (upsampfac, prec, fft_backend)

    Raises:
        NotImplementedError: if ``modeord`` is anything other than 0.
    """
    remaining = dict(kwargs)

    modeord = remaining.pop("modeord", 0)
    if modeord not in (0,):
        raise NotImplementedError("modeord=1 (FFT ordering) not implemented")

    upsampfac = remaining.pop("upsampfac", None)
    if upsampfac in (0, 0.0):  # finufft's "choose automatically" sentinel
        upsampfac = None
    prec = remaining.pop("prec", "crit64")
    fft_backend = remaining.pop("fft_backend", "mlx")  # only the type-3 slab uses it

    unknown = sorted(key for key in remaining if key not in _IGNORED_OPTS)
    if unknown:
        warnings.warn(f"mlx-nufft: ignoring unknown options {unknown}")
    return upsampfac, prec, fft_backend
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _norm_eps(eps):
    """Coerce ``eps`` to float and enforce the fp32 pipeline's accuracy floor.

    Tolerances below _EPS_FLOOR cannot be met by the fp32 GPU pipeline, so
    they are replaced by the floor and a warning is emitted. Everything else
    is returned unchanged (as a float).
    """
    tol = float(eps)
    if tol < _EPS_FLOOR:
        warnings.warn(
            f"mlx-nufft: eps={tol:g} is below the fp32 pipeline floor; "
            f"clamping to {_EPS_FLOOR:g} (see ACCEPTANCE.md accuracy notes)")
        return _EPS_FLOOR
    return tol
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _norm_isign(isign, default):
    """Reduce ``isign`` to +1 (non-negative) or -1; None selects ``default``."""
    sign = default
    if isign is not None:
        sign = 1 if isign >= 0 else -1
    return sign
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _out_dtype(*arrays):
    """Choose the complex output dtype from the inputs' precision.

    Returns np.complex128 as soon as any argument is (or converts to) a
    float64 or complex128 array, otherwise np.complex64.
    """
    double_kinds = (np.complex128, np.float64)
    if any(np.asarray(arr).dtype in double_kinds for arr in arrays):
        return np.complex128
    return np.complex64
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _vec_shape(data, inner_ndim, inner_shape=None):
    """Normalise ``data`` to a stacked (n_trans, *inner) array.

    ``data`` may carry exactly ``inner_ndim`` axes (a single vector, which is
    given a leading axis of length 1) or one extra leading n_trans axis.
    When ``inner_shape`` is supplied the trailing axes must equal it exactly
    (mirrors FINUFFT's strict size checks — no silent truncation).

    Returns:
        (n_trans, stacked_array)

    Raises:
        ValueError: on a wrong number of axes or a trailing-shape mismatch.
    """
    arr = np.asarray(data)
    if arr.ndim == inner_ndim:
        n_trans, stacked = 1, arr[None, ...]
    elif arr.ndim == inner_ndim + 1:
        n_trans, stacked = arr.shape[0], arr
    else:
        raise ValueError(f"data must have {inner_ndim} or {inner_ndim + 1} "
                         f"dims, got shape {arr.shape}")

    trailing = stacked.shape[1:]
    if inner_shape is not None and trailing != tuple(inner_shape):
        raise ValueError(f"data inner shape {trailing} must be "
                         f"{tuple(inner_shape)}")
    return n_trans, stacked
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _fill_out(out, res, dtype):
    """Cast ``res`` to ``dtype`` and deliver it, in place when ``out`` is given.

    ``out`` may have exactly the result's shape, or differ from it only by a
    leading length-1 axis on either side (the stacked n_trans == 1 form).
    With ``out`` supplied the data is copied into it and ``out`` is returned;
    otherwise the cast result is returned.

    Raises:
        ValueError: if ``out``'s shape is incompatible with the result.
    """
    res = res.astype(dtype, copy=False)
    if out is None:
        return res

    shape_ok = (out.shape == res.shape
                or out.shape == (1,) + res.shape
                or (1,) + out.shape == res.shape)
    if not shape_ok:
        raise ValueError(f"out.shape {out.shape} does not match result "
                         f"shape {res.shape}")
    np.copyto(out, res.reshape(out.shape))
    return out
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _embed3(arrs, dim):
    """Zero-pad a dim<3 coordinate tuple to 3 components for GpuT3Plan.

    Each component is converted to a flat float64 array. The missing
    ``3 - dim`` components are zero arrays of the same length, so 1D/2D
    type-3 problems run as degenerate slices of the 3D kernel.

    Args:
        arrs: sequence of ``dim`` coordinate arrays, one per axis.
        dim: problem dimension (1, 2 or 3).

    Returns:
        tuple of exactly 3 flat float64 arrays of equal length.

    Raises:
        ValueError: if ``dim`` is outside 1..3, the number of components is
            not ``dim``, or the components hold different numbers of points.
    """
    comps = [np.asarray(a, dtype=np.float64).ravel() for a in arrs]
    if not 1 <= dim <= 3 or len(comps) != dim:
        raise ValueError(f"expected {dim} coordinate arrays (dim in 1..3), "
                         f"got {len(comps)}")
    sizes = [c.size for c in comps]
    n_pts = sizes[0]
    # The zero padding is sized from the first component, so all components
    # must agree on the point count or the padded tuple is malformed.
    if any(s != n_pts for s in sizes):
        raise ValueError(f"coordinate arrays must all have the same number "
                         f"of points, got sizes {sizes}")
    pad = np.zeros(n_pts)
    return tuple(comps) + (pad,) * (3 - dim)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _modes_tuple(n_modes, dim, out, out_offset=0):
    """Resolve per-axis mode counts as a tuple of ``dim`` ints.

    An explicit ``n_modes`` wins. A scalar is repeated on every axis, and a
    sequence is converted element-wise. Otherwise the counts are read from
    ``out.shape`` starting at axis ``out_offset``.

    Raises:
        ValueError: if neither ``n_modes`` nor ``out`` is given, or if the
            selected axes of ``out`` are not exactly ``dim`` long.
    """
    if n_modes is not None:
        if np.isscalar(n_modes):
            return tuple([int(n_modes)] * dim)
        return tuple(map(int, n_modes))

    if out is None:
        raise ValueError("either n_modes or out must be supplied")
    trailing = out.shape[out_offset:]
    if len(trailing) != dim:
        raise ValueError(f"out shape {out.shape} does not match dim {dim}")
    return tuple(map(int, trailing))
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# ---------------------------------------------------------------------------
|
|
139
|
+
# type 1: nonuniform -> uniform
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _warn_no_vkfft(fft_backend, what):
    """Warn that a non-default ``fft_backend`` has no effect for ``what``.

    The default 'mlx' backend is silent. Any other value produces one warning
    naming the transform (``what``) for which the request is ignored.
    """
    if fft_backend == "mlx":
        return
    warnings.warn(f"mlx-nufft: fft_backend={fft_backend!r} applies only "
                  f"to 3D type-3 (slab); ignored for {what}")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _nufft_t1(dim, coords, c, n_modes, out, eps, isign, kwargs):
    """Shared driver for the type-1 (nonuniform -> uniform) wrappers.

    Args:
        dim: problem dimension (1, 2 or 3).
        coords: tuple of ``dim`` point-coordinate arrays.
        c: strengths, shape (M,) or (n_trans, M).
        n_modes: mode counts (scalar or per-axis sequence). If None they are
            read from the trailing ``dim`` axes of ``out``.
        out: optional output array, filled in place and returned.
        eps: requested tolerance (clamped by _norm_eps).
        isign: exponent sign, default +1.
        kwargs: finufft-style options (see _check_opts).

    Returns:
        Complex modes, unstacked when ``c`` was a single 1-D vector,
        otherwise with a leading n_trans axis.
    """
    upsampfac, prec, fft_backend = _check_opts(kwargs)
    _warn_no_vkfft(fft_backend, "type-1")
    eps = _norm_eps(eps)
    isign = _norm_isign(isign, +1)
    dtype = _out_dtype(c)
    M = np.asarray(coords[0]).size
    n_tr, cv = _vec_shape(c, 1, inner_shape=(M,))
    if out is not None and n_modes is None:
        # The mode box is the trailing `dim` axes of `out`. A leading n_trans
        # axis is skipped for multiple transforms, and also when a single
        # transform is written into the stacked (1, N1, ...) form, which
        # _fill_out accepts.
        has_trans_axis = n_tr > 1 or len(out.shape) == dim + 1
        n_modes = _modes_tuple(None, dim, out,
                               out_offset=1 if has_trans_axis else 0)
    N = _modes_tuple(n_modes, dim, out)
    kw = {} if upsampfac is None else {"upsampfac": upsampfac}
    plan = Type1PlanND(coords, N, eps=eps, isign=isign, prec=prec, **kw)
    res = np.stack([plan.execute(cv[t]) for t in range(n_tr)])
    if n_tr == 1 and np.asarray(c).ndim == 1:
        res = res[0]
    return _fill_out(out, res, dtype)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def nufft1d1(x, c, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
    """1D type-1 NUFFT (nonuniform points -> uniform modes).

    Computes f[k] = sum_j c[j] exp(+/-i k x[j]) over the modeord=0 box of
    size ``n_modes`` (taken from ``out`` when omitted). ``c`` may carry a
    leading n_trans axis.
    """
    points = (x,)
    return _nufft_t1(1, points, c, n_modes, out, eps, isign, kwargs)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def nufft2d1(x, y, c, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
    """2D type-1 NUFFT (nonuniform points -> uniform modes).

    Computes f[k1,k2] = sum_j c[j] exp(+/-i (k1 x[j] + k2 y[j])) over the
    modeord=0 box given by ``n_modes`` (taken from ``out`` when omitted).
    """
    points = (x, y)
    return _nufft_t1(2, points, c, n_modes, out, eps, isign, kwargs)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def nufft3d1(x, y, z, c, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
    """3D type-1 NUFFT (nonuniform points -> uniform modes).

    Computes f[k1,k2,k3] = sum_j c[j] exp(+/-i k . x_j) over the modeord=0
    box given by ``n_modes`` (taken from ``out`` when omitted).
    """
    points = (x, y, z)
    return _nufft_t1(3, points, c, n_modes, out, eps, isign, kwargs)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# ---------------------------------------------------------------------------
|
|
183
|
+
# type 2: uniform -> nonuniform
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _nufft_t2(dim, coords, f, out, eps, isign, kwargs):
    """Shared driver for the type-2 (uniform -> nonuniform) wrappers.

    The mode box is read from the trailing ``dim`` axes of ``f``. An extra
    leading axis stacks n_trans inputs, which all run through one plan.
    ``isign`` defaults to -1. The result is unstacked when ``f`` had exactly
    ``dim`` axes.
    """
    upsampfac, prec, fft_backend = _check_opts(kwargs)
    _warn_no_vkfft(fft_backend, "type-2")
    eps = _norm_eps(eps)
    isign = _norm_isign(isign, -1)
    dtype = _out_dtype(f)
    n_vec, modes = _vec_shape(f, dim)
    extra = {"upsampfac": upsampfac} if upsampfac is not None else {}
    plan = Type2PlanND(coords, modes.shape[1:], eps=eps, isign=isign,
                       prec=prec, **extra)
    batch = np.stack([plan.execute(modes[i]) for i in range(n_vec)])
    single = n_vec == 1 and np.asarray(f).ndim == dim
    return _fill_out(out, batch[0] if single else batch, dtype)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def nufft1d2(x, f, out=None, eps=1e-6, isign=-1, **kwargs):
    """1D type-2 NUFFT (uniform modes -> nonuniform points).

    Computes c[j] = sum_k f[k] exp(+/-i k x[j]). The mode count comes from
    the last axis of ``f``; a leading axis stacks n_trans inputs.
    """
    points = (x,)
    return _nufft_t2(1, points, f, out, eps, isign, kwargs)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def nufft2d2(x, y, f, out=None, eps=1e-6, isign=-1, **kwargs):
    """2D type-2 NUFFT (uniform modes -> nonuniform points).

    Computes c[j] = sum_{k1,k2} f[k1,k2] exp(+/-i (k1 x[j] + k2 y[j])).
    The mode box comes from the last two axes of ``f``.
    """
    points = (x, y)
    return _nufft_t2(2, points, f, out, eps, isign, kwargs)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def nufft3d2(x, y, z, f, out=None, eps=1e-6, isign=-1, **kwargs):
    """3D type-2 NUFFT (uniform modes -> nonuniform points).

    Computes c[j] = sum_k f[k] exp(+/-i k . x_j). The mode box comes from
    the last three axes of ``f``.
    """
    points = (x, y, z)
    return _nufft_t2(3, points, f, out, eps, isign, kwargs)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# ---------------------------------------------------------------------------
|
|
218
|
+
# type 3: nonuniform -> nonuniform
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _nufft_t3(dim, src, c, trg, out, eps, isign, kwargs):
    """Shared driver for the type-3 (nonuniform -> nonuniform) wrappers.

    Source and target coordinate tuples of length ``dim`` are zero-padded to
    3 components, so every dimension runs on the 3D type-3 GPU plan.
    Strengths ``c`` are (M,) or (n_trans, M), and ``isign`` defaults to +1.
    An ``upsampfac`` other than 1.25 is ignored with a warning.
    """
    upsampfac, prec, fft_backend = _check_opts(kwargs)
    if upsampfac is not None and upsampfac != 1.25:
        warnings.warn("mlx-nufft: type-3 runs the validated sigma=1.25 "
                      "pipeline; upsampfac ignored")
    eps = _norm_eps(eps)
    isign = _norm_isign(isign, +1)
    dtype = _out_dtype(c)
    n_src = np.asarray(src[0]).size
    n_vec, strengths = _vec_shape(c, 1, inner_shape=(n_src,))
    plan = GpuT3Plan(_embed3(src, dim), _embed3(trg, dim), eps=eps,
                     isign=isign, prec=prec, fft_backend=fft_backend)
    batch = np.stack([plan.execute(strengths[i]) for i in range(n_vec)])
    single = n_vec == 1 and np.asarray(c).ndim == 1
    return _fill_out(out, batch[0] if single else batch, dtype)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def nufft1d3(x, c, s, out=None, eps=1e-6, isign=1, **kwargs):
    """1D type-3 NUFFT (nonuniform points -> nonuniform frequencies).

    Computes f[k] = sum_j c[j] exp(+/-i s[k] x[j]).
    """
    sources, targets = (x,), (s,)
    return _nufft_t3(1, sources, c, targets, out, eps, isign, kwargs)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def nufft2d3(x, y, c, s, t, out=None, eps=1e-6, isign=1, **kwargs):
    """2D type-3 NUFFT (nonuniform points -> nonuniform frequencies).

    Computes f[k] = sum_j c[j] exp(+/-i (s[k] x[j] + t[k] y[j])).
    """
    sources, targets = (x, y), (s, t)
    return _nufft_t3(2, sources, c, targets, out, eps, isign, kwargs)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def nufft3d3(x, y, z, c, s, t, u, out=None, eps=1e-6, isign=1, **kwargs):
    """3D type-3 NUFFT (nonuniform points -> nonuniform frequencies).

    Computes f[k] = sum_j c[j] exp(+/-i (s,t,u)_k . (x,y,z)_j).
    """
    sources, targets = (x, y, z), (s, t, u)
    return _nufft_t3(3, sources, c, targets, out, eps, isign, kwargs)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# ---------------------------------------------------------------------------
|
|
255
|
+
# Plan interface
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class Plan:
    """finufft.Plan-compatible plan/setpts/execute interface.

    Plan(nufft_type, n_modes_or_dim, n_trans=1, eps=1e-6, isign=None,
         dtype='complex128', **kwargs)

    For types 1/2, n_modes_or_dim is the mode tuple (a scalar means 1D; dim
    is inferred from its length). For type 3 it is the dimension (1, 2 or 3).
    setpts() builds the GPU plan (points are part of plan state, as in
    cu/FINUFFT). execute() runs each of n_trans vectors through the cached
    plan, and execute_adjoint() applies the adjoint transform. Results are
    cast to ``dtype`` (complex64 or complex128).

    Raises:
        ValueError: for an unknown nufft_type, n_trans < 1, an unsupported
            dtype, or a dimension outside 1..3.
    """

    def __init__(self, nufft_type, n_modes_or_dim, n_trans=1, eps=1e-6,
                 isign=None, dtype="complex128", **kwargs):
        if nufft_type not in (1, 2, 3):
            raise ValueError("nufft_type must be 1, 2 or 3")
        self.type = int(nufft_type)
        self.n_trans = int(n_trans)
        if self.n_trans < 1:
            # Zero vectors would only fail later inside np.stack with an
            # unrelated message; reject the plan up front instead.
            raise ValueError("n_trans must be >= 1")
        self.eps = _norm_eps(eps)
        self.isign = _norm_isign(isign, -1 if self.type == 2 else +1)
        self.dtype = np.dtype(dtype)
        if self.dtype not in (np.complex64, np.complex128):
            raise ValueError("dtype must be complex64 or complex128")
        self._upsampfac, self._prec, self._fft_backend = _check_opts(kwargs)
        if self._fft_backend != "mlx" and self.type != 3:
            _warn_no_vkfft(self._fft_backend, f"type-{self.type}")
            self._fft_backend = "mlx"
        if self.type == 3:
            self.dim = int(n_modes_or_dim)
            self.n_modes = None
        else:
            if np.isscalar(n_modes_or_dim):
                n_modes_or_dim = (n_modes_or_dim,)
            self.n_modes = tuple(int(n) for n in n_modes_or_dim)
            self.dim = len(self.n_modes)
        if self.dim not in (1, 2, 3):
            raise ValueError("dim must be 1, 2 or 3")
        self._plan = None        # forward plan, built by setpts()
        self._adjoint = None     # adjoint plan, built lazily on first use
        self._n_points = None    # number of nonuniform (source) points
        self._n_targets = None   # number of type-3 target points
        self._coords = None      # setpts() coordinates (private copies)
        self._targets = None     # setpts() type-3 targets (private copies)

    def _plan_kwargs(self):
        """Extra keyword arguments forwarded to the type-1/2 plan classes."""
        return {} if self._upsampfac is None else {"upsampfac": self._upsampfac}

    def setpts(self, x=None, y=None, z=None, s=None, t=None, u=None):
        """Bind the nonuniform points (and type-3 targets s, t, u) and plan.

        Exactly ``dim`` of x/y/z (and, for type 3, of s/t/u) must be given.
        The coordinates are copied, so later in-place edits to the caller's
        arrays cannot make the lazily built adjoint disagree with the forward
        plan. If this call raises, the previously bound points and plans are
        left untouched.

        Raises:
            ValueError: on a wrong number of coordinate or target arrays.
        """
        coords = [v for v in (x, y, z) if v is not None]
        if len(coords) != self.dim:
            raise ValueError(f"expected {self.dim} coordinate arrays, "
                             f"got {len(coords)}")
        coords = [np.array(v, copy=True) for v in coords]
        targets = None
        n_targets = None
        if self.type == 1:
            plan = Type1PlanND(tuple(coords), self.n_modes,
                               eps=self.eps, isign=self.isign,
                               prec=self._prec, **self._plan_kwargs())
        elif self.type == 2:
            plan = Type2PlanND(tuple(coords), self.n_modes,
                               eps=self.eps, isign=self.isign,
                               prec=self._prec, **self._plan_kwargs())
        else:
            if self._upsampfac is not None and self._upsampfac != 1.25:
                warnings.warn("mlx-nufft: type-3 runs the validated "
                              "sigma=1.25 pipeline; upsampfac ignored")
            targets = [v for v in (s, t, u) if v is not None]
            if len(targets) != self.dim:
                raise ValueError(f"expected {self.dim} target arrays, "
                                 f"got {len(targets)}")
            targets = [np.array(v, copy=True) for v in targets]
            plan = GpuT3Plan(_embed3(coords, self.dim),
                             _embed3(targets, self.dim),
                             eps=self.eps, isign=self.isign,
                             prec=self._prec, fft_backend=self._fft_backend)
            n_targets = targets[0].size
        # Commit only once the new plan exists, so a failure above leaves
        # the previous geometry and plans consistent and usable.
        self._plan = plan
        self._adjoint = None
        self._coords = coords
        self._targets = targets
        self._n_points = coords[0].size
        self._n_targets = n_targets

    def _run(self, plan, data, out, inner, ishape):
        """Validate ``data`` against (inner, ishape) and apply ``plan`` per vector.

        Returns the stacked results (unstacked for a single unstacked input),
        cast to ``self.dtype`` and written into ``out`` when given.
        """
        n_tr, dv = _vec_shape(data, inner, inner_shape=ishape)
        if n_tr != self.n_trans:
            raise ValueError(f"data has {n_tr} vectors, plan has "
                             f"n_trans={self.n_trans}")
        res = np.stack([plan.execute(dv[k]) for k in range(self.n_trans)])
        if self.n_trans == 1 and np.asarray(data).ndim == inner:
            res = res[0]
        return _fill_out(out, res, self.dtype)

    def execute(self, data, out=None):
        """Apply the planned transform to ``data`` (n_trans vectors).

        Type 2 takes mode arrays shaped ``n_modes``; types 1 and 3 take
        strength vectors with one entry per nonuniform point.

        Raises:
            RuntimeError: if setpts() has not been called.
            ValueError: on a shape or n_trans mismatch.
        """
        if self._plan is None:
            raise RuntimeError("setpts() must be called before execute()")
        if self.type == 2:
            inner, ishape = self.dim, self.n_modes
        else:
            inner, ishape = 1, (self._n_points,)
        return self._run(self._plan, data, out, inner, ishape)

    def _build_adjoint(self):
        """Build the sibling transform with isign negated (the exact adjoint)."""
        if self.type == 1:
            return Type2PlanND(tuple(self._coords), self.n_modes,
                               eps=self.eps, isign=-self.isign,
                               prec=self._prec, **self._plan_kwargs())
        if self.type == 2:
            return Type1PlanND(tuple(self._coords), self.n_modes,
                               eps=self.eps, isign=-self.isign,
                               prec=self._prec, **self._plan_kwargs())
        # type 3: swap the roles of sources and targets
        return GpuT3Plan(_embed3(self._targets, self.dim),
                         _embed3(self._coords, self.dim),
                         eps=self.eps, isign=-self.isign,
                         prec=self._prec, fft_backend=self._fft_backend)

    def execute_adjoint(self, data, out=None):
        """Apply the adjoint of the planned transform (finufft 2.5 API).

        Type-1 plan adjoint maps modes -> points; type-2 adjoint maps
        points -> modes; type-3 adjoint maps targets -> sources. Implemented
        as the sibling transform with isign negated (the exact adjoint of
        the NUFFT matrix; the validated type-3 adjoint identity). The adjoint
        plan is built on first use and cached until the next setpts().

        Raises:
            RuntimeError: if setpts() has not been called.
            ValueError: on a shape or n_trans mismatch.
        """
        if self._plan is None:
            raise RuntimeError("setpts() must be called before "
                               "execute_adjoint()")
        if self._adjoint is None:
            self._adjoint = self._build_adjoint()
        if self.type == 1:
            inner, ishape = self.dim, self.n_modes
        elif self.type == 2:
            inner, ishape = 1, (self._n_points,)
        else:
            inner, ishape = 1, (self._n_targets,)
        return self._run(self._adjoint, data, out, inner, ishape)
|
mlx_nufft/dfmath.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""General double-single (df64) GPU math primitives.
|
|
2
|
+
|
|
3
|
+
These expose the extended-precision machinery the NUFFT plans use internally
|
|
4
|
+
(df64 phase accumulation + reduction mod 2pi, then f32 transcendentals) as
|
|
5
|
+
standalone tools for callers that hit the same f32-precision wall — most
|
|
6
|
+
commonly a large-magnitude fp64 phase that cannot be reduced mod 2pi in f32
|
|
7
|
+
but whose cos/sin you want to evaluate on the GPU.
|
|
8
|
+
|
|
9
|
+
The reduction is identical to the prephase inside GpuT3Plan.set_sources: form
|
|
10
|
+
the phase in df64, k = rint(phi/2pi), then phi - k*2pi in df64 (the product
|
|
11
|
+
k*2pi is exact via two_prod), and f32 cos/sin of the O(1) residual. Valid while
|
|
12
|
+
the integer quotient phi/2pi stays f32-exact, i.e. |phi| <~ 2^24 * 2pi ~ 1.05e8
|
|
13
|
+
radians; beyond that the residual grows and accuracy degrades gracefully.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import mlx.core as mx
|
|
18
|
+
|
|
19
|
+
from .gpu_t3 import _DF64_HDR
|
|
20
|
+
|
|
21
|
+
# Double-precision pi used by the df64 helpers in this module.
PI = np.pi

# Largest |phi| for which rint(phi / 2pi) is still an exactly representable
# f32 integer: 2**24 * 2pi, about 1.054e8 radians. expi() keeps its
# documented accuracy up to this ceiling.
EXPI_MAX_PHASE = float((1 << 24) * 2.0 * PI)

# Compiled expi Metal kernels, keyed by the number of phase components.
_expi_cache = {}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _build_expi_kernel(ncomp: int):
    """Build the MLX Metal kernel behind expi() for `ncomp` phase components.

    The kernel computes e^{i*isign*phi} for phi = sum of `ncomp` df64 phase
    components: accumulate in df64, reduce mod 2pi (k = rint(phi/2pi);
    phi - k*2pi in df64, k*2pi exact via two_prod), then f32 cos/sin of the
    O(1) residual.

    Kernel inputs, in order: ph_hi0, ph_lo0, ..., ph_hi{ncomp-1},
    ph_lo{ncomp-1} (the hi/lo f32 halves of each component), then
    cst = [2pi_hi, 2pi_lo, 1/2pi, isign] and P0 (element count in P0[0]).
    The output `out` holds interleaved (cos, sin) pairs at out[2*j] and
    out[2*j+1]; threads with j >= P0[0] return without writing.

    The kernel is only constructed here, not launched.
    """
    # df64 accumulation of the phase components, unrolled into the source.
    acc = " df64 ph = df_make(0.0f, 0.0f);\n"
    for c in range(ncomp):
        acc += f" ph = df_add(ph, df_make(ph_hi{c}[j], ph_lo{c}[j]));\n"
    # Per-thread body: bounds check, accumulate, reduce mod 2pi, cos/sin.
    src = f"""
uint j = thread_position_in_grid.x;
if (j >= (uint)P0[0]) return;
{acc} float k = metal::rint(ph.hi * cst[2]); // ph.hi / 2pi
df64 red = df_add(ph, df_mul(df_make(-k, 0.0f),
df_make(cst[0], cst[1]))); // ph - k*2pi (df64)
float ang = cst[3] * (red.hi + red.lo);
out[2*j] = metal::precise::cos(ang);
out[2*j+1] = metal::precise::sin(ang);
"""
    # Input names must match the identifiers referenced in `src`.
    innames = []
    for c in range(ncomp):
        innames += [f"ph_hi{c}", f"ph_lo{c}"]
    innames += ["cst", "P0"]
    # _DF64_HDR (imported from gpu_t3) is expected to define the df64 type
    # and the df_* helpers that `src` uses.
    return mx.fast.metal_kernel(
        name=f"expi_df64_n{ncomp}", input_names=innames,
        output_names=["out"], header=_DF64_HDR, source=src)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def expi(phases, isign=1, return_np=False):
    """Compute e^{i*isign*phi} on the GPU for fp64 phases too large for f32.

    The phase is formed and reduced modulo 2pi in double-single (df64)
    arithmetic. Only the O(1) residual goes through f32 cos/sin. This is the
    same machinery as the prephase inside GpuT3Plan.set_sources, exposed as
    a general tool. f32 alone loses all fractional bits of a phase by
    |phi| ~ 1e3, so naive on-GPU evaluation fails in that regime.

    Parameters
    ----------
    phases : fp64 array, or a list/tuple of fp64 arrays
        A single array is the phase phi itself. A list/tuple is a set of
        same-shape phase components that are summed in df64 on the GPU,
        phi = sum_k phases[k]. This is more accurate than summing in fp64 on
        the host.
    isign : int
        Sign of the exponent; only its sign matters (>= 0 -> +1, else -1).
    return_np : bool
        If True, return a numpy complex64 array. Otherwise return the device
        mx.array (complex64).

    Returns
    -------
    e^{i*isign*phi} as complex64, shaped like one phase component.

    Raises
    ------
    ValueError
        In three cases:
        - no components are given;
        - the component shapes differ;
        - a list/tuple of several scalars is passed. It would silently
          collapse into one summed phase; wrap it in np.asarray(...) for
          per-element phases instead.

    Accuracy
    --------
    About 1e-6 per element versus an fp64 np.exp reference for |phi| up to
    EXPI_MAX_PHASE (~1.05e8 rad): ~2e-7 for |phi| <~ 1e7, nearing ~2e-6 at
    the ceiling. Beyond it, the quotient phi/2pi is no longer an exact f32
    integer and the error grows roughly linearly with |phi| (no cliff).
    NaN/inf phases give NaN in that element only.
    """
    # Normalise the input into a list of float64 components.
    if isinstance(phases, (list, tuple)):
        parts = [np.asarray(p, dtype=np.float64) for p in phases]
    else:
        parts = [np.asarray(phases, dtype=np.float64)]
    if not parts:
        raise ValueError("expi: at least one phase component is required")
    shape = parts[0].shape
    if not all(p.shape == shape for p in parts):
        raise ValueError("expi: all phase components must share one shape")
    if shape == () and len(parts) > 1:
        raise ValueError(
            "expi: got a multi-element sequence of scalars, which would be "
            "summed into a single phase. Pass np.asarray(...) for a vector of "
            "per-element phases, or 1-D arrays as summable phase components.")

    sign = 1 if isign >= 0 else -1
    n_elem = 1 if not shape else int(np.prod(shape))
    n_parts = len(parts)

    kernel = _expi_cache.get(n_parts)
    if kernel is None:
        kernel = _build_expi_kernel(n_parts)
        _expi_cache[n_parts] = kernel

    # Kernel constants: 2pi split into f32 hi/lo parts, 1/2pi, and the sign.
    two_pi = 2.0 * PI
    consts = np.zeros(4, dtype=np.float32)
    consts[0] = np.float32(two_pi)
    consts[1] = np.float32(two_pi - np.float64(consts[0]))
    consts[2] = np.float32(1.0 / two_pi)
    consts[3] = np.float32(sign)

    # Split every component into f32 hi + f32 lo (its fp64 residual).
    inputs = []
    with np.errstate(invalid="ignore"):  # inf-in -> nan-out, no warning
        for part in parts:
            flat = part.ravel()
            hi = flat.astype(np.float32)
            lo = (flat - hi).astype(np.float32)
            inputs.extend([mx.array(hi), mx.array(lo)])
    inputs.append(mx.array(consts))
    inputs.append(mx.array(np.array([n_elem], dtype=np.int32)))

    # The kernel writes interleaved (cos, sin) f32 pairs; view them as complex.
    raw = kernel(inputs=inputs, output_shapes=[(2 * n_elem,)],
                 output_dtypes=[mx.float32],
                 grid=(n_elem, 1, 1), threadgroup=(256, 1, 1))[0]
    result = mx.reshape(mx.view(raw, dtype=mx.complex64), shape)
    mx.eval(result)
    if return_np:
        return np.array(result)
    return result
|