pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,734 @@
1
+ """ Utility functions for cupy backend.
2
+
3
+ The functions spline_filter, _prepad_for_spline_filter, _filter_input,
4
+ _get_coord_affine_batched and affine_transform are largely copied from
5
+ cupyx.scipy.ndimage which operates under the following license
6
+
7
+ Copyright (c) 2015 Preferred Infrastructure, Inc.
8
+ Copyright (c) 2015 Preferred Networks, Inc.
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in
18
+ all copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26
+ THE SOFTWARE.
27
+
28
+ I have since extended the functionality of the cupyx.scipy.ndimage functions
29
+ in question to support batched inputs.
30
+
31
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
32
+ """
33
+
34
+ import numpy
35
+ import cupy
36
+
37
+ from cupy import _core
38
+ from cupyx.scipy.ndimage import (
39
+ _util,
40
+ _interp_kernels,
41
+ _interpolation,
42
+ spline_filter1d,
43
+ _spline_prefilter_core,
44
+ _spline_kernel_weights,
45
+ )
46
+
47
# Per-order CUDA source templates computing B-spline interpolation weights.
# ``_generate_interp_custom`` indexes this by spline order and fills the
# template via ``spline_weights_inline[order].format(j=..., order=...)``.
spline_weights_inline = _spline_kernel_weights.spline_weights_inline


# Preamble handed to the batched affine ElementwiseKernel. It includes
# cupy/math_constants.h so that the CUDART_NAN / CUDART_INF constants emitted
# for non-finite ``cval`` resolve (presumably defined in that header — confirm).
# The leading comment line is part of the string on purpose (see the HIP note
# inside it), so the literal must be kept verbatim.
math_constants_preamble = r"""
// workaround for HIP: line begins with #include
#include <cupy/math_constants.h>
"""
54
+
55
+
56
def _prepad_for_spline_filter(input, mode, cval, batched=True):
    """
    Prepad the input array for spline filtering.

    Only the ``'nearest'`` and ``'grid-constant'`` modes are padded
    explicitly. For every other mode the input is returned unchanged.

    Parameters
    ----------
    input : CupyArray
        The input array to be padded.
    mode : str
        Determines how input points outside the boundaries are handled.
    cval : scalar
        Constant value to use for padding if mode is 'grid-constant'.
    batched : bool, optional
        Whether the input has a leading batch dimension, by default True.
        The batch axis is never padded.

    Returns
    -------
    padded : CupyArray
        The padded input array (``input`` itself when no padding is needed).
    npad : int or tuple of tuples
        ``0`` if no padding was applied. Otherwise, the per-axis
        ``(before, after)`` pad widths given to :func:`cupy.pad`: ``(0, 0)``
        for the batch axis and ``(12, 12)`` for every other axis.
    """
    if mode not in ("nearest", "grid-constant"):
        return input, 0

    # Empirical pad width chosen by SciPy; the batch axis stays unpadded.
    npad = tuple(
        (0, 0) if (batched and axis == 0) else (12, 12)
        for axis in range(input.ndim)
    )
    if mode == "grid-constant":
        pad_kwargs = {"mode": "constant", "constant_values": cval}
    else:
        pad_kwargs = {"mode": "edge"}
    return cupy.pad(input, npad, **pad_kwargs), npad
92
+
93
+
94
def spline_filter(input, order=3, output=cupy.float64, mode="mirror", batched=True):
    """Multidimensional spline filter.

    Applies :func:`cupyx.scipy.ndimage.spline_filter1d` successively along
    every axis of ``input``. When ``batched`` is True, the leading batch axis
    is skipped.

    Parameters
    ----------
    input : CupyArray
        The input array.
    order : int, optional
        The order of the spline, default is 3. Must be in the range 2-5.
    output : CupyArray or dtype, optional
        The array in which to place the output, or the dtype of the returned array.
    mode : str, optional
        Determines how input points outside the boundaries are handled.
    batched : bool, optional
        Whether the input has a leading batch dimension. Default is True.

    Returns
    -------
    CupyArray
        The result of prefiltering the input. If ``output`` is an array, the
        filtered data is also copied into it.

    Raises
    ------
    RuntimeError
        If ``order`` is outside the range 2-5.

    See Also
    --------
    :obj:`scipy.ndimage.spline_filter`
    """
    if order < 2 or order > 5:
        raise RuntimeError("spline order not supported")

    x = input
    temp, _, output_dtype = _interpolation._get_spline_output(x, output)
    # order is guaranteed to lie in 2..5 here, so every non-batch axis is
    # filtered. A 0-d input (or a batch-only input) yields an empty range.
    for axis in range(int(batched), input.ndim):
        spline_filter1d(x, order, axis, output=temp, mode=mode)
        x = temp
    if isinstance(output, cupy.ndarray):
        _core.elementwise_copy(temp, output)
    else:
        output = temp
    if output.dtype != output_dtype:
        output = output.astype(output_dtype)
    return output
136
+
137
+
138
def _scalar_prepad(npad):
    """Return the per-edge pad width encoded in a ``cupy.pad``-style spec.

    The interpolation kernel generators expect a single integer ``nprepad``.
    ``_prepad_for_spline_filter`` may report per-axis ``(before, after)``
    pairs instead; axes that were not padded contribute 0.
    """
    if not isinstance(npad, tuple):
        return int(npad)
    return max((int(before) for before, _ in npad), default=0)


def _filter_input(image, prefilter, mode, cval, order, batched=True):
    """
    Perform spline prefiltering on the input image if requested.

    Parameters
    ----------
    image : CupyArray
        The input image to be filtered.
    prefilter : bool
        Whether to apply prefiltering or not.
    mode : str
        The boundary mode to use. See `cupy.pad` for details.
    cval : scalar
        Value used for padding when mode is 'grid-constant'.
    order : int
        The order of the spline interpolation. Must be in the range 0-5.
        No filtering is applied for ``order < 2``.
    batched : bool, optional
        Whether the input has a leading batch dimension. Default is True.

    Returns
    -------
    filtered_image : CupyArray
        The (possibly filtered) image as a contiguous array. When filtered,
        its dtype is ``image.dtype`` promoted to at least float32.
    npad : int
        The amount of padding applied at each edge of the padded axes
        (0 if no padding was applied).
    """
    if not prefilter or order < 2:
        return (cupy.ascontiguousarray(image), 0)
    padded, npad = _prepad_for_spline_filter(image, mode, cval, batched=batched)
    float_dtype = cupy.promote_types(image.dtype, cupy.float32)
    filtered = spline_filter(
        padded, order, output=float_dtype, mode=mode, batched=batched
    )
    # The kernel generators evaluate ``nprepad > 0`` and format it into CUDA
    # code; a tuple of pad pairs would raise TypeError there, so collapse it.
    return cupy.ascontiguousarray(filtered), _scalar_prepad(npad)
172
+
173
+
174
def _get_coord_affine_batched(ndim, nprepad=0):
    """
    Generate CUDA statements mapping output to input coordinates for a
    batched affine transform.

    Parameters
    ----------
    ndim : int
        Number of dimensions of the input, including the leading batch axis.
    nprepad : int, optional
        Number of elements the spatial axes were prepadded with, by default 0.
        It is added to every spatial coordinate; the batch coordinate is
        passed through unchanged.

    Returns
    -------
    list of str
        A list of CUDA statements computing ``c_0 ... c_{ndim-1}``.

    Notes
    -----
    The generated code assumes these variables exist on the device:

    - mat (array): row-major ``(ndim - 1, ndim)`` matrix. It holds the
      ``(ndim - 1, ndim - 1)`` spatial linear part, with the offset as its
      last column.
    - in_coord (array): unraveled output coordinates.

    For example, with ``ndim=3`` (a batch of 2D images)::

        c_0 = in_coord[0]
        c_1 = mat[0] * in_coord[1] + mat[1] * in_coord[2] + mat[2]
        c_2 = mat[3] * in_coord[1] + mat[4] * in_coord[2] + mat[5]
    """
    ops = []
    if ndim < 1:
        return ops

    pre = f" + (W){nprepad}" if nprepad > 0 else ""

    # Leading batch axis: copied verbatim, never transformed or prepadded.
    ops.append(
        """
    W c_0 = (W)in_coord[0];"""
    )

    # ``mat`` has ndim columns: ndim - 1 spatial coefficients plus the offset.
    ncol = ndim
    for j in range(1, ndim):
        row_start = (j - 1) * ncol
        ops.append(
            f"""
    W c_{j} = (W)0.0;"""
        )
        for k in range(1, ndim):
            ops.append(
                f"""
    c_{j} += mat[{row_start + k - 1}] * (W)in_coord[{k}];"""
            )
        ops.append(
            f"""
    c_{j} += mat[{row_start + ndim - 1}]{pre};"""
        )
    return ops
232
+
233
+
234
def _generate_interp_custom(
    coord_func,
    ndim,
    large_int,
    yshape,
    mode,
    cval,
    order,
    name="",
    integer_output=False,
    nprepad=0,
    omit_in_coord=False,
    batched=False,
):
    """
    Generate the body and name of an interpolation ElementwiseKernel.

    Args:
        coord_func (function): generates code to do the coordinate
            transformation. It is called as ``coord_func(ndim, nprepad)`` and
            must return a list of CUDA statements defining ``c_0 ... c_{ndim-1}``
            (see for example ``_get_coord_affine_batched``).
        ndim (int): The number of dimensions, including the batch axis when
            ``batched`` is True.
        large_int (bool): If true use ptrdiff_t/size_t instead of
            int/unsigned int for indexing.
        yshape (tuple): Shape of the output array.
        mode (str): Signal extension mode to use at the array boundaries.
        cval (float): constant value used when ``mode == 'constant'`` or
            ``mode == 'grid-constant'``.
        order (int): Spline interpolation order.
        name (str): base name for the interpolation kernel.
        integer_output (bool): boolean indicating whether the output has an
            integer type.
        nprepad (int): integer indicating the amount of prepadding at the
            boundaries; forwarded to ``coord_func``.
        omit_in_coord (bool): If true, skip emitting the code that unravels
            the output index into ``in_coord``.
        batched (bool): If true and ``order >= 1``, axis 0 is a batch axis.
            It is indexed directly (its coordinate is truncated to an integer)
            and is not interpolated.

    Returns:
        operation (str): code body for the ElementwiseKernel
        name (str): name for the ElementwiseKernel
    """

    ops = []
    internal_dtype = "double" if integer_output else "Y"
    ops.append(f"{internal_dtype} out = 0.0;")

    if large_int:
        uint_t = "size_t"
        int_t = "ptrdiff_t"
    else:
        uint_t = "unsigned int"
        int_t = "int"

    # determine strides for x along each axis
    for j in range(ndim):
        ops.append(f"const {int_t} xsize_{j} = x.shape()[{j}];")
    ops.append(f"const {uint_t} sx_{ndim - 1} = 1;")
    for j in range(ndim - 1, 0, -1):
        ops.append(f"const {uint_t} sx_{j - 1} = sx_{j} * xsize_{j};")

    if not omit_in_coord:
        # create in_coords array to store the unraveled indices
        ops.append(_interp_kernels._unravel_loop_index(yshape, uint_t))

    # compute the transformed (target) coordinates, c_j
    ops = ops + coord_func(ndim, nprepad)

    # Detect NaN by value: an identity test against ``numpy.nan`` misses other
    # NaN objects (e.g. float("nan"), numpy.float32("nan")). Those would
    # otherwise be formatted as a bare ``nan`` in the generated code.
    if isinstance(cval, (float, numpy.floating)) and numpy.isnan(cval):
        cval = "(Y)CUDART_NAN"
    elif cval == numpy.inf:
        cval = "(Y)CUDART_INF"
    elif cval == -numpy.inf:
        cval = "(Y)(-CUDART_INF)"
    else:
        cval = f"({internal_dtype}){cval}"

    if mode == "constant":
        # use cval if coordinate is outside the bounds of x
        _cond = " || ".join(
            [f"(c_{j} < 0) || (c_{j} > xsize_{j} - 1)" for j in range(ndim)]
        )
        ops.append(
            f"""
        if ({_cond})
        {{
            out = {cval};
        }}
        else
        {{"""
        )

    if order == 0:
        if mode == "wrap":
            ops.append("double dcoord;")  # mode 'wrap' requires this to work
        for j in range(ndim):
            # determine nearest neighbor
            if mode == "wrap":
                ops.append(
                    f"""
            dcoord = c_{j};"""
                )
            else:
                ops.append(
                    f"""
            {int_t} cf_{j} = ({int_t})floor((double)c_{j} + 0.5);"""
                )

            # handle boundary
            if mode != "constant":
                if mode == "wrap":
                    ixvar = "dcoord"
                    float_ix = True
                else:
                    ixvar = f"cf_{j}"
                    float_ix = False
                ops.append(
                    _util._generate_boundary_condition_ops(
                        mode, ixvar, f"xsize_{j}", int_t, float_ix
                    )
                )
                if mode == "wrap":
                    ops.append(
                        f"""
            {int_t} cf_{j} = ({int_t})floor(dcoord + 0.5);"""
                    )

            # sum over ic_j will give the raveled coordinate in the input
            ops.append(
                f"""
            {int_t} ic_{j} = cf_{j} * sx_{j};"""
            )
        _coord_idx = " + ".join([f"ic_{j}" for j in range(ndim)])
        if mode == "grid-constant":
            _cond = " || ".join([f"(ic_{j} < 0)" for j in range(ndim)])
            ops.append(
                f"""
            if ({_cond}) {{
                out = {cval};
            }} else {{
                out = ({internal_dtype})x[{_coord_idx}];
            }}"""
            )
        else:
            ops.append(
                f"""
            out = ({internal_dtype})x[{_coord_idx}];"""
            )

    elif order == 1:
        if batched:
            # The batch axis is indexed directly. Use int_t (not a hard-coded
            # int) so the offset does not overflow when large_int is requested.
            ops.append(
                f"""
            {int_t} ic_0 = ({int_t})c_0 * sx_0;"""
            )
        for j in range(int(batched), ndim):
            # get coordinates for linear interpolation along axis j
            ops.append(
                f"""
            {int_t} cf_{j} = ({int_t})floor((double)c_{j});
            {int_t} cc_{j} = cf_{j} + 1;
            {int_t} n_{j} = (c_{j} == cf_{j}) ? 1 : 2; // points needed
            """
            )

            if mode == "wrap":
                ops.append(
                    f"""
            double dcoordf = c_{j};
            double dcoordc = c_{j} + 1;"""
                )
            else:
                # handle boundaries for extension modes.
                ops.append(
                    f"""
            {int_t} cf_bounded_{j} = cf_{j};
            {int_t} cc_bounded_{j} = cc_{j};"""
                )

            if mode != "constant":
                if mode == "wrap":
                    ixvar = "dcoordf"
                    float_ix = True
                else:
                    ixvar = f"cf_bounded_{j}"
                    float_ix = False
                ops.append(
                    _util._generate_boundary_condition_ops(
                        mode, ixvar, f"xsize_{j}", int_t, float_ix
                    )
                )

                ixvar = "dcoordc" if mode == "wrap" else f"cc_bounded_{j}"
                ops.append(
                    _util._generate_boundary_condition_ops(
                        mode, ixvar, f"xsize_{j}", int_t, float_ix
                    )
                )
                if mode == "wrap":
                    # NOTE(review): cc_bounded is derived from dcoordf rather
                    # than the separately wrapped dcoordc — confirm intended.
                    ops.append(
                        f"""
            {int_t} cf_bounded_{j} = ({int_t})floor(dcoordf);;
            {int_t} cc_bounded_{j} = ({int_t})floor(dcoordf + 1);;
            """
                    )

            # Opens one loop per interpolated axis; closed after the loop body.
            ops.append(
                f"""
            for (int s_{j} = 0; s_{j} < n_{j}; s_{j}++)
            {{
                W w_{j};
                {int_t} ic_{j};
                if (s_{j} == 0)
                {{
                    w_{j} = (W)cc_{j} - c_{j};
                    ic_{j} = cf_bounded_{j} * sx_{j};
                }} else
                {{
                    w_{j} = c_{j} - (W)cf_{j};
                    ic_{j} = cc_bounded_{j} * sx_{j};
                }}"""
            )
    elif order > 1:
        if mode == "grid-constant":
            spline_mode = "constant"
        elif mode == "nearest":
            spline_mode = "nearest"
        else:
            spline_mode = _spline_prefilter_core._get_spline_mode(mode)

        # wx, wy are temporary variables used during spline weight computation
        ops.append(
            f"""
            W wx, wy;
            {int_t} start;"""
        )

        if batched:
            # The batch axis is indexed directly. Use int_t (not a hard-coded
            # int) so the offset does not overflow when large_int is requested.
            ops.append(
                f"""
            {int_t} ic_0 = ({int_t})c_0 * sx_0;"""
            )
        for j in range(int(batched), ndim):
            # determine weights along the current axis
            ops.append(
                f"""
            W weights_{j}[{order + 1}];"""
            )
            ops.append(spline_weights_inline[order].format(j=j, order=order))

            # get starting coordinate for spline interpolation along axis j
            if mode in ["wrap"]:
                ops.append(f"double dcoord = c_{j};")
                coord_var = "dcoord"
                ops.append(
                    _util._generate_boundary_condition_ops(
                        mode, coord_var, f"xsize_{j}", int_t, True
                    )
                )
            else:
                coord_var = f"(double)c_{j}"

            if order & 1:
                op_str = """
            start = ({int_t})floor({coord_var}) - {order_2};"""
            else:
                op_str = """
            start = ({int_t})floor({coord_var} + 0.5) - {order_2};"""
            ops.append(
                op_str.format(int_t=int_t, coord_var=coord_var, order_2=order // 2)
            )

            # set of coordinate values within spline footprint along axis j
            ops.append(f"""{int_t} ci_{j}[{order + 1}];""")
            for k in range(order + 1):
                ixvar = f"ci_{j}[{k}]"
                ops.append(
                    f"""
            {ixvar} = start + {k};"""
                )
                ops.append(
                    _util._generate_boundary_condition_ops(
                        spline_mode, ixvar, f"xsize_{j}", int_t
                    )
                )

            # loop over the order + 1 values in the spline filter
            ops.append(
                f"""
            W w_{j};
            {int_t} ic_{j};
            for (int k_{j} = 0; k_{j} <= {order}; k_{j}++)
            {{
                w_{j} = weights_{j}[k_{j}];
                ic_{j} = ci_{j}[k_{j}] * sx_{j};
            """
            )

    if order > 0:
        # The batch axis carries no weight; it only contributes ic_0.
        _weight = " * ".join([f"w_{j}" for j in range(int(batched), ndim)])
        _coord_idx = " + ".join([f"ic_{j}" for j in range(ndim)])
        if mode == "grid-constant" or (order > 1 and mode == "constant"):
            _cond = " || ".join([f"(ic_{j} < 0)" for j in range(ndim)])
            ops.append(
                f"""
            if ({_cond}) {{
                out += {cval} * ({internal_dtype})({_weight});
            }} else {{
                {internal_dtype} val = ({internal_dtype})x[{_coord_idx}];
                out += val * ({internal_dtype})({_weight});
            }}"""
            )
        else:
            ops.append(
                f"""
            {internal_dtype} val = ({internal_dtype})x[{_coord_idx}];
            out += val * ({internal_dtype})({_weight});"""
            )

        # close the per-axis loops opened above (none for the batch axis)
        ops.append("}" * (ndim - int(batched)))

    if mode == "constant":
        # close the ``else`` branch of the out-of-bounds test
        ops.append("}")

    if integer_output:
        ops.append("y = (Y)rint((double)out);")
    else:
        ops.append("y = (Y)out;")
    operation = "\n".join(ops)

    mode_str = mode.replace("-", "_")  # avoid hyphen in kernel name
    name = "cupyx_scipy_ndimage_interpolate_{}_order{}_{}_{}d_y{}".format(
        name,
        order,
        mode_str,
        ndim,
        "_".join([f"{j}" for j in yshape]),
    )
    if uint_t == "size_t":
        name += "_i64"
    return operation, name
566
+
567
+
568
@cupy._util.memoize(for_each_device=True)
def _get_batched_affine_kernel(
    ndim, large_int, yshape, mode, cval=0.0, order=1, integer_output=False, nprepad=0
):
    """Build the ElementwiseKernel for a batched affine transform.

    The kernel reads the raw input ``x`` and the flattened affine matrix
    ``mat``, and writes one interpolated value per output element ``y``.
    Results are memoized per device by the decorator.
    """
    kernel_body, kernel_name = _generate_interp_custom(
        _get_coord_affine_batched,
        ndim,
        large_int,
        yshape,
        mode,
        cval,
        order,
        name="affine_batched",
        integer_output=integer_output,
        nprepad=nprepad,
        batched=True,
    )
    return cupy.ElementwiseKernel(
        "raw X x, raw W mat",
        "Y y",
        kernel_body,
        kernel_name,
        preamble=math_constants_preamble,
    )
594
+
595
+
596
def affine_transform_batch(
    input,
    matrix,
    offset=0.0,
    output_shape=None,
    output=None,
    order=3,
    mode="constant",
    cval=0.0,
    prefilter=True,
    *,
    batched=True,
):
    """
    Apply an affine transformation.

    Parameters
    ----------
    input : CupyArray
        The input array.
    matrix : CupyArray
        The inverse coordinate transformation matrix, mapping output coordinates
        to input coordinates. Let ``n`` be the number of spatial dimensions:
        ``input.ndim - 1`` if ``batched``, otherwise ``input.ndim``. The
        supported shapes are:

        - ``(n, n)``: linear transformation matrix; the translation comes
          from `offset`.
        - ``(n,)``: diagonal of the linear transformation matrix; the
          translation comes from `offset`.
        - ``(n + 1, n + 1)``: homogeneous coordinates (ignores `offset`).
        - ``(n, n + 1)``: as above, but omits the bottom row
          ``[0, 0, ..., 1]``.

        The batch axis is never transformed.
    offset : float or sequence, optional
        Offset added to the transformed coordinates. If a float, `offset` is
        the same for each spatial axis. If a sequence, it must contain one
        value per spatial axis (``n`` values). Ignored for homogeneous
        matrices. Default is 0.0.
    output_shape : tuple of ints, optional
        Shape tuple of the output. Defaults to ``input.shape``.
    output : CupyArray or dtype, optional
        The array in which to place the output, or the dtype of the returned array.
    order : int, optional
        The order of the spline interpolation. Must be in the range 0-5.
        Default is 3.
    mode : str, optional
        Determines how input points outside the boundaries are handled.
        Default is 'constant'. Available options are

        +---------------+----------------------------------------------------+
        | 'constant'    | Fill with a constant value                         |
        +---------------+----------------------------------------------------+
        | 'nearest'     | Use the nearest pixel's value                      |
        +---------------+----------------------------------------------------+
        | 'mirror'      | Mirror the pixels at the boundary                  |
        +---------------+----------------------------------------------------+
        | 'reflect'     | Reflect the pixels at the boundary                 |
        +---------------+----------------------------------------------------+
        | 'wrap'        | Wrap the pixels at the boundary                    |
        +---------------+----------------------------------------------------+
        | 'grid-mirror' | Mirror the grid at the boundary                    |
        +---------------+----------------------------------------------------+
        | 'grid-wrap'   | Wrap the grid at the boundary                      |
        +---------------+----------------------------------------------------+
        | 'grid-        | Use a constant value for grid points outside the   |
        | constant'     | boundary                                           |
        +---------------+----------------------------------------------------+
        | 'opencv'      | OpenCV border mode                                 |
        +---------------+----------------------------------------------------+

        Note that 'opencv' is passed to the kernel as-is; no OpenCV-specific
        conversion of `matrix` is performed here.
    cval : scalar, optional
        Value used for points outside the boundaries of the input if
        ``mode='constant'`` or ``mode='opencv'``. Default is 0.0.
    prefilter : bool, optional
        Whether to prefilter the input array with `spline_filter` before
        interpolation. Default is True.
    batched : bool, optional
        Whether the input has a leading batch dimension. Default is True.

    Returns
    -------
    CupyArray
        The output array holding the transformed input. It is returned in all
        cases, including when `output` is passed as an array.

    Raises
    ------
    RuntimeError
        If `matrix` has an unsupported shape, or if `offset` does not have one
        value per spatial axis. Invalid `order` or `mode` values are rejected
        by ``_interpolation._check_parameter``.

    Notes
    -----
    When `prefilter` is True and `order > 1`, a temporary floating-point
    array (at least float32) of filtered values is created. If `prefilter`
    is False and `order > 1`, the output may be slightly blurred unless the
    input is prefiltered.

    When `batched` is True, the function treats the first dimension of the
    input as a batch dimension.
    """
    _interpolation._check_parameter("affine_transform", order, mode)

    ndim = input.ndim
    # Axes the transform acts on; a leading batch axis is excluded.
    spatial_ndim = ndim - int(batched)

    if matrix.ndim == 1:
        # Diagonal linear part; the translation comes from ``offset``.
        matrix = cupy.diag(matrix)
        offset = _util._fix_sequence_arg(offset, spatial_ndim, "offset", float)
    elif matrix.ndim == 2 and matrix.shape[0] >= 1:
        if matrix.shape[0] == matrix.shape[1] - 1:
            offset = matrix[:, -1]
            matrix = matrix[:, :-1]
        elif matrix.shape[0] == spatial_ndim + 1:
            offset = matrix[:-1, -1]
            matrix = matrix[:-1, :-1]
        else:
            # Normalize against the spatial rank: the offset fills a column
            # with one entry per spatial axis, so expanding it to input.ndim
            # entries would not fit when batched.
            offset = _util._fix_sequence_arg(
                offset, spatial_ndim, "offset", float
            )
    else:
        raise RuntimeError("no proper affine matrix provided")

    # Reject shapes that would otherwise be silently broadcast into ``m``.
    if matrix.shape != (spatial_ndim, spatial_ndim):
        raise RuntimeError("improper affine shape")

    if output_shape is None:
        output_shape = input.shape

    matrix = matrix.astype(cupy.float64, copy=False)
    output = _util._get_output(output, input, shape=output_shape)
    if input.dtype.kind in "iu":
        input = input.astype(cupy.float32)
    filtered, nprepad = _filter_input(input, prefilter, mode, cval, order, batched)

    integer_output = output.dtype.kind in "iu"
    _util._check_cval(mode, cval, integer_output)
    large_int = (
        max(_core.internal.prod(input.shape), _core.internal.prod(output_shape))
        > 1 << 31
    )

    kernel = _interp_kernels._get_affine_kernel
    if batched:
        kernel = _get_batched_affine_kernel
    kern = kernel(
        ndim,
        large_int,
        output_shape,
        mode,
        cval=cval,
        order=order,
        integer_output=integer_output,
        nprepad=nprepad,
    )
    # Flattened (spatial_ndim, spatial_ndim + 1) matrix: linear part + offset.
    m = cupy.zeros((spatial_ndim, spatial_ndim + 1), dtype=cupy.float64)
    m[:, :-1] = matrix
    m[:, -1] = cupy.asarray(offset, dtype=cupy.float64)
    kern(filtered, m, output)
    return output