mlx-nufft 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlx_nufft/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ """mlx-nufft: non-uniform FFTs on Apple GPUs (Metal, via MLX).
2
+
3
+ Drop-in mirror of the `finufft` Python package's interface:
4
+
5
+ import mlx_nufft as finufft
6
+ fk = finufft.nufft2d1(x, y, c, (N1, N2), eps=1e-6) # types 1/2/3, dims 1/2/3
7
+ plan = finufft.Plan(1, (N1, N2), eps=1e-6)
8
+ plan.setpts(x, y)
9
+ fk = plan.execute(c)
10
+
11
+ plus the native plan classes:
12
+
13
+ from mlx_nufft import Type3Plan
14
+ plan = Type3Plan((x1, x2, x3), (s1, s2, s3), eps=1e-5, isign=+1)
15
+ f = plan.execute(c) # f[k] = sum_j c[j] exp(i*isign*s_k.x_j)
16
+
17
+ Precision model (see REPORT.md): fp32 GPU pipeline with the precision-
18
+ critical setup (coordinate rescale, pre/post phases) in fp64 at plan time
19
+ ('crit64', the default). Plans cache all geometry-dependent state, so
20
+ fixed-geometry workloads pay setup once and execute() per call.
21
+ """
22
+
23
+ from .gpu_t3 import GpuT3Plan as Type3Plan
24
+ from .types12 import Type1Plan, Type2Plan
25
+ from .nd import Type1PlanND, Type2PlanND
26
+ from .dfmath import expi, EXPI_MAX_PHASE
27
+ from .sizing import kernel_params, next235even
28
+ from .api import (Plan,
29
+ nufft1d1, nufft1d2, nufft1d3,
30
+ nufft2d1, nufft2d2, nufft2d3,
31
+ nufft3d1, nufft3d2, nufft3d3)
32
+
33
+
34
def vkfft_available():
    """Report whether the optional VkFFT-Metal FFT backend is usable.

    This is the backend selected with ``fft_backend='vkfft'``; it has to be
    built and loadable (see vkfft_bridge/build.sh).  The backend module is
    imported lazily, at call time, and its own ``available()`` result is
    returned unchanged.
    """
    from . import vkfft_backend as _backend
    return _backend.available()
39
+
40
+
41
__version__ = "0.1.1"

# Names exported by ``from mlx_nufft import *``.
__all__ = [
    # native plan classes
    "Type3Plan",
    "Type1Plan",
    "Type2Plan",
    "Type1PlanND",
    "Type2PlanND",
    # finufft-style plan object
    "Plan",
    # finufft-style one-shot transforms
    "nufft1d1",
    "nufft1d2",
    "nufft1d3",
    "nufft2d1",
    "nufft2d2",
    "nufft2d3",
    "nufft3d1",
    "nufft3d2",
    "nufft3d3",
    # sizing helpers and backend probe
    "kernel_params",
    "next235even",
    "vkfft_available",
    # df64 math primitives
    "expi",
    "EXPI_MAX_PHASE",
]
mlx_nufft/api.py ADDED
@@ -0,0 +1,388 @@
1
+ """finufft-compatible API for mlx_nufft.
2
+
3
+ Drop-in mirror of the `finufft` Python package's interface, so existing
4
+ callers switch with one import line:
5
+
6
+ import mlx_nufft as finufft
7
+ fk = finufft.nufft2d1(x, y, c, (N1, N2), eps=1e-6)
8
+ plan = finufft.Plan(1, (N1, N2), n_trans=8, eps=1e-6)
9
+ plan.setpts(x, y)
10
+ fk = plan.execute(c)
11
+
12
+ Semantics mirrored from finufft 2.x:
13
+ - mode boxes are modeord=0: k_d integer in [-(N_d//2), (N_d-1)//2],
14
+ even or odd N_d;
15
+ - isign: non-negative means +i in the exponential (type-1/3 default +1,
16
+ type-2 default -1);
17
+ - multi-vector inputs: leading n_trans axis on strengths/mode arrays;
18
+ - out= arrays are filled in place when supplied;
19
+ - x (etc.) in [-pi, pi), folded otherwise.
20
+
21
+ Differences (documented, not silent):
22
+ - computation is the validated fp32 GPU pipeline with fp64-critical
23
+ setup ('crit64'); requesting eps below 1e-6 clamps to 1e-6 with a
24
+ warning (the fp32 accuracy envelope: see ACCEPTANCE.md);
25
+ - complex128 inputs are accepted and returned as complex128, but the
26
+ transform itself is fp32-grade;
27
+ - modeord=1 (FFT ordering) is not implemented (raises);
28
+ - 1D/2D type 3 currently run as degenerate slices of the validated 3D
29
+ type-3 kernel (functional; not speed-tuned).
30
+ """
31
+
32
+ import warnings
33
+
34
+ import numpy as np
35
+
36
+ from .nd import Type1PlanND, Type2PlanND
37
+ from .gpu_t3 import GpuT3Plan
38
+
39
+ _EPS_FLOOR = 1e-6
40
+ _IGNORED_OPTS = {
41
+ "nthreads", "debug", "spread_debug", "showwarn", "fftw", "spread_sort",
42
+ "spread_kerevalmeth", "spread_kerpad", "chkbnds", "maxbatchsize",
43
+ "spread_thread", "spread_nthr_atomic", "spread_max_sp_size",
44
+ }
45
+
46
+
47
+ def _check_opts(kwargs):
48
+ opts = dict(kwargs)
49
+ if opts.pop("modeord", 0) not in (0,):
50
+ raise NotImplementedError("modeord=1 (FFT ordering) not implemented")
51
+ upsampfac = opts.pop("upsampfac", None)
52
+ if upsampfac in (0, 0.0): # finufft auto sentinel
53
+ upsampfac = None
54
+ prec = opts.pop("prec", "crit64")
55
+ fft_backend = opts.pop("fft_backend", "mlx") # type-3 slab only
56
+ for k in list(opts):
57
+ if k in _IGNORED_OPTS:
58
+ opts.pop(k)
59
+ if opts:
60
+ warnings.warn(f"mlx-nufft: ignoring unknown options {sorted(opts)}")
61
+ return upsampfac, prec, fft_backend
62
+
63
+
64
def _norm_eps(eps):
    """Return ``eps`` as a float, raised to ``_EPS_FLOOR`` if it is smaller.

    A warning is emitted whenever the requested tolerance is clamped.
    """
    requested = float(eps)
    if not (requested < _EPS_FLOOR):
        return requested
    warnings.warn(
        f"mlx-nufft: eps={requested:g} is below the fp32 pipeline floor; "
        f"clamping to {_EPS_FLOOR:g} (see ACCEPTANCE.md accuracy notes)")
    return _EPS_FLOOR
72
+
73
+
74
+ def _norm_isign(isign, default):
75
+ if isign is None:
76
+ return default
77
+ return +1 if isign >= 0 else -1
78
+
79
+
80
+ def _out_dtype(*arrays):
81
+ for a in arrays:
82
+ if np.asarray(a).dtype in (np.complex128, np.float64):
83
+ return np.complex128
84
+ return np.complex64
85
+
86
+
87
+ def _vec_shape(data, inner_ndim, inner_shape=None):
88
+ """Split data shape into (n_trans, inner shape); inner_ndim trailing.
89
+ If inner_shape is given, the trailing dims must match it exactly
90
+ (mirrors FINUFFT's strict size checks — no silent truncation)."""
91
+ data = np.asarray(data)
92
+ if data.ndim == inner_ndim:
93
+ out = 1, data[None, ...]
94
+ elif data.ndim == inner_ndim + 1:
95
+ out = data.shape[0], data
96
+ else:
97
+ raise ValueError(f"data must have {inner_ndim} or {inner_ndim + 1} "
98
+ f"dims, got shape {data.shape}")
99
+ if inner_shape is not None and out[1].shape[1:] != tuple(inner_shape):
100
+ raise ValueError(f"data inner shape {out[1].shape[1:]} must be "
101
+ f"{tuple(inner_shape)}")
102
+ return out
103
+
104
+
105
+ def _fill_out(out, res, dtype):
106
+ res = res.astype(dtype, copy=False)
107
+ if out is not None:
108
+ # exact shape, or the (1, ...) stacked form when n_trans == 1
109
+ if out.shape != res.shape and out.shape != (1,) + res.shape \
110
+ and (1,) + out.shape != res.shape:
111
+ raise ValueError(f"out.shape {out.shape} does not match result "
112
+ f"shape {res.shape}")
113
+ np.copyto(out, res.reshape(out.shape))
114
+ return out
115
+ return res
116
+
117
+
118
+ def _embed3(arrs, dim):
119
+ """Zero-pad a dim<3 coordinate tuple to 3 components for GpuT3Plan."""
120
+ arrs = [np.asarray(a, dtype=np.float64).ravel() for a in arrs]
121
+ z = np.zeros(arrs[0].size)
122
+ return tuple(arrs) + (z,) * (3 - dim)
123
+
124
+
125
+ def _modes_tuple(n_modes, dim, out, out_offset=0):
126
+ if n_modes is None:
127
+ if out is None:
128
+ raise ValueError("either n_modes or out must be supplied")
129
+ shape = out.shape[out_offset:]
130
+ if len(shape) != dim:
131
+ raise ValueError(f"out shape {out.shape} does not match dim {dim}")
132
+ return tuple(int(n) for n in shape)
133
+ if np.isscalar(n_modes):
134
+ return (int(n_modes),) * dim
135
+ return tuple(int(n) for n in n_modes)
136
+
137
+
138
+ # ---------------------------------------------------------------------------
139
+ # type 1: nonuniform -> uniform
140
+
141
+
142
+ def _warn_no_vkfft(fft_backend, what):
143
+ if fft_backend != "mlx":
144
+ warnings.warn(f"mlx-nufft: fft_backend={fft_backend!r} applies only "
145
+ f"to 3D type-3 (slab); ignored for {what}")
146
+
147
+
148
def _nufft_t1(dim, coords, c, n_modes, out, eps, isign, kwargs):
    """Shared driver behind the one-shot type-1 wrappers.

    Parameters
    ----------
    dim : int
        Number of dimensions (length of ``coords``).
    coords : tuple of arrays
        Nonuniform point coordinates, one array per dimension.
    c : array
        Strengths, shape ``(M,)`` or ``(n_trans, M)`` with ``M`` the size of
        ``coords[0]``.
    n_modes : None, int or sequence
        Mode-grid shape; when None it is taken from ``out``.
    out : array or None
        Optional output array, filled in place.
    eps, isign, kwargs
        Tolerance, exponent sign (default +1) and finufft-style options.

    Returns
    -------
    The mode array (``out`` when supplied).  A 1-D ``c`` yields an
    unstacked result; a 2-D ``c`` yields a leading n_trans axis.

    Raises
    ------
    ValueError
        On inconsistent ``c`` / ``out`` / ``n_modes`` shapes.
    """
    upsampfac, prec, fft_backend = _check_opts(kwargs)
    _warn_no_vkfft(fft_backend, "type-1")
    eps = _norm_eps(eps)
    isign = _norm_isign(isign, +1)
    dtype = _out_dtype(c)
    M = np.asarray(coords[0]).size
    n_tr, cv = _vec_shape(c, 1, inner_shape=(M,))
    if out is not None and n_modes is None:
        # Decide whether `out` carries a leading n_trans axis from its own
        # rank, not from n_tr: a single transform may legitimately be given
        # as a stacked (1, N...) output, which _fill_out accepts.
        stacked_out = out.ndim == dim + 1
        if stacked_out and out.shape[0] != n_tr:
            raise ValueError(f"out shape {out.shape} does not match "
                             f"n_trans={n_tr}")
        n_modes = _modes_tuple(None, dim, out,
                               out_offset=(1 if stacked_out else 0))
    N = _modes_tuple(n_modes, dim, out)
    kw = {} if upsampfac is None else {"upsampfac": upsampfac}
    plan = Type1PlanND(coords, N, eps=eps, isign=isign, prec=prec, **kw)
    res = np.stack([plan.execute(cv[t]) for t in range(n_tr)])
    if n_tr == 1 and (np.asarray(c).ndim == 1):
        # un-stack only when the caller passed a single 1-D strength vector
        res = res[0]
    return _fill_out(out, res, dtype)
165
+
166
+
167
def nufft1d1(x, c, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
    """1D type-1 (nonuniform -> uniform): f[k] = sum_j c[j] exp(+/-i k x(j)).

    Thin wrapper: forwards to the shared type-1 driver with ``dim=1``;
    extra keyword options are passed on as a dict.
    """
    points = (x,)
    return _nufft_t1(1, points, c, n_modes, out, eps, isign, kwargs)
170
+
171
+
172
def nufft2d1(x, y, c, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
    """2D type-1 (nonuniform -> uniform):
    f[k1,k2] = sum_j c[j] exp(+/-i (k1 x(j) + k2 y(j))).

    Thin wrapper: forwards to the shared type-1 driver with ``dim=2``;
    extra keyword options are passed on as a dict.
    """
    points = (x, y)
    return _nufft_t1(2, points, c, n_modes, out, eps, isign, kwargs)
175
+
176
+
177
def nufft3d1(x, y, z, c, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
    """3D type-1 (nonuniform -> uniform):
    f[k1,k2,k3] = sum_j c[j] exp(+/-i k . x_j).

    Thin wrapper: forwards to the shared type-1 driver with ``dim=3``;
    extra keyword options are passed on as a dict.
    """
    points = (x, y, z)
    return _nufft_t1(3, points, c, n_modes, out, eps, isign, kwargs)
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # type 2: uniform -> nonuniform
184
+
185
+
186
def _nufft_t2(dim, coords, f, out, eps, isign, kwargs):
    """Shared driver behind the one-shot type-2 wrappers.

    ``f`` holds the mode coefficients with ``dim`` dims, or ``dim + 1`` dims
    when a leading n_trans axis is present; the mode-grid shape is read from
    its trailing dims.  One plan is built for ``coords`` and executed once
    per stacked vector.  The result is written into ``out`` when supplied.
    The default exponent sign is -1.
    """
    upsampfac, prec, fft_backend = _check_opts(kwargs)
    _warn_no_vkfft(fft_backend, "type-2")
    tol = _norm_eps(eps)
    sign = _norm_isign(isign, -1)
    result_dtype = _out_dtype(f)

    n_tr, stacked = _vec_shape(f, dim)
    mode_shape = stacked.shape[1:]

    plan_opts = {"eps": tol, "isign": sign, "prec": prec}
    if upsampfac is not None:
        plan_opts["upsampfac"] = upsampfac
    plan = Type2PlanND(coords, mode_shape, **plan_opts)

    results = np.stack([plan.execute(stacked[idx]) for idx in range(n_tr)])
    single_unstacked_input = n_tr == 1 and np.asarray(f).ndim == dim
    if single_unstacked_input:
        results = results[0]
    return _fill_out(out, results, result_dtype)
200
+
201
+
202
def nufft1d2(x, f, out=None, eps=1e-6, isign=-1, **kwargs):
    """1D type-2 (uniform -> nonuniform): c[j] = sum_k f[k] exp(+/-i k x(j)).

    Thin wrapper: forwards to the shared type-2 driver with ``dim=1``;
    extra keyword options are passed on as a dict.
    """
    points = (x,)
    return _nufft_t2(1, points, f, out, eps, isign, kwargs)
205
+
206
+
207
def nufft2d2(x, y, f, out=None, eps=1e-6, isign=-1, **kwargs):
    """2D type-2 (uniform -> nonuniform):
    c[j] = sum_{k1,k2} f[k1,k2] exp(+/-i (k1 x(j) + k2 y(j))).

    Thin wrapper: forwards to the shared type-2 driver with ``dim=2``;
    extra keyword options are passed on as a dict.
    """
    points = (x, y)
    return _nufft_t2(2, points, f, out, eps, isign, kwargs)
210
+
211
+
212
def nufft3d2(x, y, z, f, out=None, eps=1e-6, isign=-1, **kwargs):
    """3D type-2 (uniform -> nonuniform): c[j] = sum_k f[k] exp(+/-i k . x_j).

    Thin wrapper: forwards to the shared type-2 driver with ``dim=3``;
    extra keyword options are passed on as a dict.
    """
    points = (x, y, z)
    return _nufft_t2(3, points, f, out, eps, isign, kwargs)
215
+
216
+
217
+ # ---------------------------------------------------------------------------
218
+ # type 3: nonuniform -> nonuniform
219
+
220
+
221
def _nufft_t3(dim, src, c, trg, out, eps, isign, kwargs):
    """Shared driver behind the one-shot type-3 wrappers.

    ``src`` and ``trg`` are ``dim``-tuples of source / target coordinates;
    both are zero-padded to three components and handed to GpuT3Plan.
    ``c`` holds the strengths, shape ``(M,)`` or ``(n_trans, M)`` with ``M``
    the size of ``src[0]``.  A requested ``upsampfac`` other than 1.25 is
    ignored with a warning.  The result is written into ``out`` when
    supplied.  The default exponent sign is +1.
    """
    upsampfac, prec, fft_backend = _check_opts(kwargs)
    if not (upsampfac is None or upsampfac == 1.25):
        warnings.warn("mlx-nufft: type-3 runs the validated sigma=1.25 "
                      "pipeline; upsampfac ignored")
    tol = _norm_eps(eps)
    sign = _norm_isign(isign, +1)
    result_dtype = _out_dtype(c)

    n_sources = np.asarray(src[0]).size
    n_tr, stacked = _vec_shape(c, 1, inner_shape=(n_sources,))

    sources3 = _embed3(src, dim)
    targets3 = _embed3(trg, dim)
    plan = GpuT3Plan(sources3, targets3,
                     eps=tol, isign=sign, prec=prec, fft_backend=fft_backend)

    results = np.stack([plan.execute(stacked[idx]) for idx in range(n_tr)])
    single_unstacked_input = n_tr == 1 and np.asarray(c).ndim == 1
    if single_unstacked_input:
        results = results[0]
    return _fill_out(out, results, result_dtype)
237
+
238
+
239
def nufft1d3(x, c, s, out=None, eps=1e-6, isign=1, **kwargs):
    """1D type-3 (nonuniform -> nonuniform):
    f[k] = sum_j c[j] exp(+/-i s[k] x[j]).

    Thin wrapper: forwards to the shared type-3 driver with ``dim=1``;
    extra keyword options are passed on as a dict.
    """
    sources = (x,)
    targets = (s,)
    return _nufft_t3(1, sources, c, targets, out, eps, isign, kwargs)
242
+
243
+
244
def nufft2d3(x, y, c, s, t, out=None, eps=1e-6, isign=1, **kwargs):
    """2D type-3 (nonuniform -> nonuniform):
    f[k] = sum_j c[j] exp(+/-i (s[k] x[j] + t[k] y[j])).

    Thin wrapper: forwards to the shared type-3 driver with ``dim=2``;
    extra keyword options are passed on as a dict.
    """
    sources = (x, y)
    targets = (s, t)
    return _nufft_t3(2, sources, c, targets, out, eps, isign, kwargs)
247
+
248
+
249
def nufft3d3(x, y, z, c, s, t, u, out=None, eps=1e-6, isign=1, **kwargs):
    """3D type-3 (nonuniform -> nonuniform):
    f[k] = sum_j c[j] exp(+/-i (s,t,u)_k . (x,y,z)_j).

    Thin wrapper: forwards to the shared type-3 driver with ``dim=3``;
    extra keyword options are passed on as a dict.
    """
    sources = (x, y, z)
    targets = (s, t, u)
    return _nufft_t3(3, sources, c, targets, out, eps, isign, kwargs)
252
+
253
+
254
+ # ---------------------------------------------------------------------------
255
+ # Plan interface
256
+
257
+
258
class Plan:
    """finufft.Plan-compatible plan/setpts/execute interface.

        Plan(nufft_type, n_modes_or_dim, n_trans=1, eps=1e-6, isign=None,
             dtype='complex128', **kwargs)

    For types 1/2, n_modes_or_dim is the mode tuple (dim inferred from its
    length). For type 3 it is the dimension (1, 2 or 3). setpts() builds the
    GPU plan (points are part of plan state, as in cuFINUFFT/FINUFFT);
    execute() runs each of n_trans vectors through the cached plan.
    """

    def __init__(self, nufft_type, n_modes_or_dim, n_trans=1, eps=1e-6,
                 isign=None, dtype="complex128", **kwargs):
        """Record the transform parameters; no GPU plan is built yet.

        Raises ValueError for an unsupported ``nufft_type``, a non-positive
        ``n_trans``, a ``dtype`` other than complex64/complex128, or a
        dimension outside 1..3.
        """
        if nufft_type not in (1, 2, 3):
            raise ValueError("nufft_type must be 1, 2 or 3")
        self.type = int(nufft_type)
        self.n_trans = int(n_trans)
        if self.n_trans < 1:
            # execute() stacks n_trans results; zero or negative counts can
            # never succeed, so reject them up front.
            raise ValueError("n_trans must be a positive integer")
        self.eps = _norm_eps(eps)
        self.isign = _norm_isign(isign, -1 if self.type == 2 else +1)
        self.dtype = np.dtype(dtype)
        if self.dtype not in (np.complex64, np.complex128):
            raise ValueError("dtype must be complex64 or complex128")
        self._upsampfac, self._prec, self._fft_backend = _check_opts(kwargs)
        if self._fft_backend != "mlx" and self.type != 3:
            _warn_no_vkfft(self._fft_backend, f"type-{self.type}")
            self._fft_backend = "mlx"
        if self.type == 3:
            self.dim = int(n_modes_or_dim)
            self.n_modes = None
        else:
            if np.isscalar(n_modes_or_dim):
                n_modes_or_dim = (n_modes_or_dim,)
            self.n_modes = tuple(int(n) for n in n_modes_or_dim)
            self.dim = len(self.n_modes)
        if self.dim not in (1, 2, 3):
            raise ValueError("dim must be 1, 2 or 3")
        # State filled in by setpts() / execute_adjoint().
        self._plan = None        # forward plan
        self._adjoint = None     # lazily built adjoint plan
        self._n_targets = None   # number of type-3 targets
        self._coords = None      # point coordinates kept for the adjoint
        self._targets = None     # type-3 target coordinates (else None)

    def _plan_kwargs(self):
        """Optional keyword arguments for the type-1/2 plan constructors."""
        if self._upsampfac is None:
            return {}
        return {"upsampfac": self._upsampfac}

    def _run(self, plan, data, inner, ishape, out):
        """Validate ``data``, run ``plan`` on each vector, deliver the result.

        ``inner`` / ``ishape`` are the rank and exact shape of one input
        vector.  Raises ValueError when the shape or the number of stacked
        vectors does not match the plan.
        """
        n_tr, dv = _vec_shape(data, inner, inner_shape=ishape)
        if n_tr != self.n_trans:
            raise ValueError(f"data has {n_tr} vectors, plan has "
                             f"n_trans={self.n_trans}")
        res = np.stack([plan.execute(dv[k]) for k in range(self.n_trans)])
        if self.n_trans == 1 and np.asarray(data).ndim == inner:
            # un-stack only when the caller passed a single unstacked vector
            res = res[0]
        return _fill_out(out, res, self.dtype)

    def setpts(self, x=None, y=None, z=None, s=None, t=None, u=None):
        """Set the nonuniform points and build the forward plan.

        ``x, y, z`` are the point coordinates (exactly ``dim`` of them must
        be given); ``s, t, u`` are the type-3 targets (exactly ``dim`` of
        them for a type-3 plan).  Any previously built adjoint plan is
        discarded.  Raises ValueError on a wrong number of arrays.
        """
        coords = [v for v in (x, y, z) if v is not None]
        if len(coords) != self.dim:
            raise ValueError(f"expected {self.dim} coordinate arrays, "
                             f"got {len(coords)}")
        kw = self._plan_kwargs()
        self._adjoint = None
        trg = None
        if self.type == 1:
            self._plan = Type1PlanND(tuple(coords), self.n_modes,
                                     eps=self.eps, isign=self.isign,
                                     prec=self._prec, **kw)
        elif self.type == 2:
            self._plan = Type2PlanND(tuple(coords), self.n_modes,
                                     eps=self.eps, isign=self.isign,
                                     prec=self._prec, **kw)
        else:
            if self._upsampfac is not None and self._upsampfac != 1.25:
                warnings.warn("mlx-nufft: type-3 runs the validated "
                              "sigma=1.25 pipeline; upsampfac ignored")
            trg = [v for v in (s, t, u) if v is not None]
            if len(trg) != self.dim:
                raise ValueError(f"expected {self.dim} target arrays, "
                                 f"got {len(trg)}")
            self._plan = GpuT3Plan(_embed3(coords, self.dim),
                                   _embed3(trg, self.dim),
                                   eps=self.eps, isign=self.isign,
                                   prec=self._prec,
                                   fft_backend=self._fft_backend)
            self._n_targets = np.asarray(trg[0]).size
        self._coords = [np.asarray(v) for v in coords]
        self._targets = None if trg is None else [np.asarray(v) for v in trg]

    def execute(self, data, out=None):
        """Apply the planned transform to ``data``.

        For a type-2 plan ``data`` has the mode-grid shape; for types 1 and
        3 it is a length-``P`` vector, ``P`` being the forward plan's ``P``
        attribute.  A leading n_trans axis is accepted in either case.  The
        result is written into ``out`` when supplied.  Raises RuntimeError
        if setpts() has not been called.
        """
        if self._plan is None:
            raise RuntimeError("setpts() must be called before execute()")
        if self.type == 2:
            inner, ishape = self.dim, self.n_modes
        else:
            inner, ishape = 1, (self._plan.P,)
        return self._run(self._plan, data, inner, ishape, out)

    def execute_adjoint(self, data, out=None):
        """Apply the adjoint of the planned transform (finufft 2.5 API).

        Type-1 plan adjoint maps modes -> points; type-2 adjoint maps
        points -> modes; type-3 adjoint maps targets -> sources. Implemented
        as the sibling transform with isign negated (the exact adjoint of
        the NUFFT matrix; the validated type-3 adjoint identity).

        The adjoint plan is built on first use and cached until the next
        setpts().  Raises RuntimeError if setpts() has not been called.
        """
        if self._plan is None:
            raise RuntimeError("setpts() must be called before "
                               "execute_adjoint()")
        if self._adjoint is None:
            kw = self._plan_kwargs()
            if self.type == 1:
                self._adjoint = Type2PlanND(tuple(self._coords), self.n_modes,
                                            eps=self.eps, isign=-self.isign,
                                            prec=self._prec, **kw)
            elif self.type == 2:
                self._adjoint = Type1PlanND(tuple(self._coords), self.n_modes,
                                            eps=self.eps, isign=-self.isign,
                                            prec=self._prec, **kw)
            else:
                # sources and targets swap roles for the type-3 adjoint
                self._adjoint = GpuT3Plan(_embed3(self._targets, self.dim),
                                          _embed3(self._coords, self.dim),
                                          eps=self.eps, isign=-self.isign,
                                          prec=self._prec,
                                          fft_backend=self._fft_backend)
        if self.type == 1:
            inner, ishape = self.dim, self.n_modes
        elif self.type == 2:
            inner, ishape = 1, (self._plan.P,)
        else:
            inner, ishape = 1, (self._n_targets,)
        return self._run(self._adjoint, data, inner, ishape, out)
mlx_nufft/dfmath.py ADDED
@@ -0,0 +1,134 @@
1
+ """General double-single (df64) GPU math primitives.
2
+
3
+ These expose the extended-precision machinery the NUFFT plans use internally
4
+ (df64 phase accumulation + reduction mod 2pi, then f32 transcendentals) as
5
+ standalone tools for callers that hit the same f32-precision wall — most
6
+ commonly a large-magnitude fp64 phase that cannot be reduced mod 2pi in f32
7
+ but whose cos/sin you want to evaluate on the GPU.
8
+
9
+ The reduction is identical to the prephase inside GpuT3Plan.set_sources: form
10
+ the phase in df64, k = rint(phi/2pi), then phi - k*2pi in df64 (the product
11
+ k*2pi is exact via two_prod), and f32 cos/sin of the O(1) residual. Valid while
12
+ the integer quotient phi/2pi stays f32-exact, i.e. |phi| <~ 2^24 * 2pi ~ 1.05e8
13
+ radians; beyond that the residual grows and accuracy degrades gracefully.
14
+ """
15
+
16
+ import numpy as np
17
+ import mlx.core as mx
18
+
19
+ from .gpu_t3 import _DF64_HDR
20
+
21
PI = np.pi

# Valid magnitude ceiling (~1.054e8 radians): the largest |phi| for which
# rint(phi / 2pi) is still an exact f32 integer.
EXPI_MAX_PHASE = float((2 ** 24) * 2.0 * PI)

# Built expi kernels, keyed by the number of phase components.
_expi_cache = dict()
27
+
28
+
29
def _build_expi_kernel(ncomp):
    """e^{i*isign*phi} for phi = sum of `ncomp` df64 phase components:
    accumulate in df64, reduce mod 2pi (k = rint(phi/2pi); phi - k*2pi in
    df64, k*2pi exact via two_prod), then f32 cos/sin of the O(1) residual.

    consts: [2pi_hi, 2pi_lo, 1/2pi, isign].

    Kernel inputs, in order: ph_hi0, ph_lo0, ..., ph_hi{ncomp-1},
    ph_lo{ncomp-1}, cst, P0.  P0[0] is the element count used as the thread
    bound.  The single output `out` holds interleaved (cos, sin) pairs,
    two floats per element.

    Returns the object produced by mx.fast.metal_kernel.
    """
    # Generated Metal statements: start from a df64 zero and add each
    # (hi, lo) component pair for element j.
    acc = " df64 ph = df_make(0.0f, 0.0f);\n"
    for c in range(ncomp):
        acc += f" ph = df_add(ph, df_make(ph_hi{c}[j], ph_lo{c}[j]));\n"
    # Kernel body: one thread per element, guarded against threads past the
    # element count.
    src = f"""
uint j = thread_position_in_grid.x;
if (j >= (uint)P0[0]) return;
{acc} float k = metal::rint(ph.hi * cst[2]); // ph.hi / 2pi
df64 red = df_add(ph, df_mul(df_make(-k, 0.0f),
df_make(cst[0], cst[1]))); // ph - k*2pi (df64)
float ang = cst[3] * (red.hi + red.lo);
out[2*j] = metal::precise::cos(ang);
out[2*j+1] = metal::precise::sin(ang);
"""
    # Input names must follow the order in which expi() passes its arrays.
    innames = []
    for c in range(ncomp):
        innames += [f"ph_hi{c}", f"ph_lo{c}"]
    innames += ["cst", "P0"]
    return mx.fast.metal_kernel(
        name=f"expi_df64_n{ncomp}", input_names=innames,
        output_names=["out"], header=_DF64_HDR, source=src)
55
+
56
+
57
def expi(phases, isign=1, return_np=False):
    """Compute e^{i*isign*phi} on the GPU for a large-magnitude fp64 phase via
    a double-single (df64) reduction mod 2pi followed by f32 cos/sin — the
    general form of the prephase machinery inside GpuT3Plan.set_sources.

    Use this whenever an fp64 phase is too large to reduce in f32 (f32 loses
    all fractional bits by |phi| ~ 1e3) but you want the cos/sin on the GPU.

    Parameters
    ----------
    phases : fp64 array, or a sequence of fp64 arrays
        Either the phase phi directly, or a small set of phase components of
        identical shape that are summed *in df64* to form phi = sum_k phases[k]
        (more accurate than an fp64 host sum, and keeps the sum on-GPU).
    isign : int
        +1 (default) or -1; computes e^{i*isign*phi}.
    return_np : bool
        If True, copy the result to a numpy complex64 array; otherwise return
        the device mx.array (complex64).

    Returns
    -------
    e^{i*isign*phi} as complex64, same shape as the input.

    Raises
    ------
    ValueError
        If no phase component is given, if the components do not share one
        shape, or if a multi-element sequence of scalars is passed.

    Accuracy: matches an fp64 host reference (np.exp) to ~1e-6 per element for
    |phi| up to EXPI_MAX_PHASE (~1.05e8 rad) — approaching ~2e-6 right at the
    ceiling, ~2e-7 for |phi| <~ 1e7. Beyond the ceiling the integer quotient
    phi/2pi is no longer f32-exact and accuracy degrades gracefully (~linear in
    |phi|, no cliff).

    Notes
    -----
    - A sequence of arrays is treated as SUMMABLE phase components (summed in
      df64). A flat python list/tuple of *scalars* is therefore rejected — it
      would otherwise be silently summed into a single phase; pass
      ``np.asarray(...)`` for a vector of per-element phases.
    - ``isign`` is normalized by sign only (>=0 -> +1, else -1).
    - NaN/inf phases propagate per element to NaN (no cross-element effect).
    - An empty input returns an empty complex64 result of the same shape
      without launching the kernel.
    """
    if isinstance(phases, (list, tuple)):
        comps = [np.asarray(v, dtype=np.float64) for v in phases]
    else:
        comps = [np.asarray(phases, dtype=np.float64)]
    if len(comps) < 1:
        raise ValueError("expi: at least one phase component is required")
    shape = comps[0].shape
    if any(c.shape != shape for c in comps):
        raise ValueError("expi: all phase components must share one shape")
    if len(comps) > 1 and shape == ():
        raise ValueError(
            "expi: got a multi-element sequence of scalars, which would be "
            "summed into a single phase. Pass np.asarray(...) for a vector of "
            "per-element phases, or 1-D arrays as summable phase components.")
    isign = +1 if isign >= 0 else -1
    P = int(np.prod(shape)) if shape else 1
    if P == 0:
        # Nothing to compute: skip the kernel launch rather than dispatching
        # a zero-sized grid.
        if return_np:
            return np.zeros(shape, dtype=np.complex64)
        return mx.zeros(shape, dtype=mx.complex64)
    ncomp = len(comps)
    kern = _expi_cache.get(ncomp)
    if kern is None:
        kern = _expi_cache[ncomp] = _build_expi_kernel(ncomp)
    # consts: [2pi_hi, 2pi_lo, 1/2pi, isign]; 2pi is split into an f32 head
    # and the f32 rounding of what the head leaves behind.
    cst = np.zeros(4, dtype=np.float32)
    cst[0] = np.float32(2.0 * PI)
    cst[1] = np.float32(2.0 * PI - np.float64(cst[0]))
    cst[2] = np.float32(1.0 / (2.0 * PI))
    cst[3] = np.float32(isign)
    ins = []
    with np.errstate(invalid="ignore"):  # inf-in -> nan-out, no warning
        for c in comps:
            # split each fp64 component into an f32 (hi, lo) pair
            cf = c.ravel()
            hi = cf.astype(np.float32)
            lo = (cf - hi).astype(np.float32)
            ins += [mx.array(hi), mx.array(lo)]
    ins += [mx.array(cst), mx.array(np.array([P], dtype=np.int32))]
    out = kern(inputs=ins, output_shapes=[(2 * P,)],
               output_dtypes=[mx.float32],
               grid=(P, 1, 1), threadgroup=(256, 1, 1))[0]
    # reinterpret the interleaved (cos, sin) floats as complex64
    res = mx.reshape(mx.view(out, dtype=mx.complex64), shape)
    mx.eval(res)
    return np.array(res) if return_np else res