pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
@@ -10,16 +10,19 @@ from typing import Tuple
10
10
  from functools import partial
11
11
 
12
12
  import jax.numpy as jnp
13
- from jax import pmap, lax
13
+ from jax import pmap, lax, vmap
14
14
 
15
15
  from ..types import BackendArray
16
16
  from ..backends import backend as be
17
17
  from ..matching_utils import normalize_template as _normalize_template
18
18
 
19
19
 
20
+ __all__ = ["scan"]
21
+
22
+
20
23
  def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
21
24
  """
22
- Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
25
+ Computes :py:meth:`tme.matching_scores.cc_setup`.
23
26
  """
24
27
  template_ft = jnp.fft.rfftn(template, s=template.shape)
25
28
  template_ft = template_ft.at[:].multiply(ft_target)
@@ -28,18 +31,17 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
28
31
 
29
32
 
30
33
  def _flc_scoring(
31
- template: BackendArray,
32
- template_mask: BackendArray,
33
34
  ft_target: BackendArray,
34
35
  ft_target2: BackendArray,
36
+ template: BackendArray,
37
+ template_mask: BackendArray,
35
38
  n_observations: BackendArray,
36
39
  eps: float,
37
40
  **kwargs,
38
41
  ) -> BackendArray:
39
42
  """
40
- Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
43
+ Computes :py:meth:`tme.matching_scores.flc_scoring`.
41
44
  """
42
- correlation = _correlate(template=template, ft_target=ft_target)
43
45
  inv_denominator = _reciprocal_target_std(
44
46
  ft_target=ft_target,
45
47
  ft_target2=ft_target2,
@@ -47,18 +49,17 @@ def _flc_scoring(
47
49
  eps=eps,
48
50
  n_observations=n_observations,
49
51
  )
50
- correlation = correlation.at[:].multiply(inv_denominator)
51
- return correlation
52
+ return _flcSphere_scoring(ft_target, template, inv_denominator)
52
53
 
53
54
 
54
55
  def _flcSphere_scoring(
55
- template: BackendArray,
56
56
  ft_target: BackendArray,
57
+ template: BackendArray,
57
58
  inv_denominator: BackendArray,
58
59
  **kwargs,
59
60
  ) -> BackendArray:
60
61
  """
61
- Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
62
+ Computes :py:meth:`tme.matching_scores.corr_scoring`.
62
63
  """
63
64
  correlation = _correlate(template=template, ft_target=ft_target)
64
65
  correlation = correlation.at[:].multiply(inv_denominator)
@@ -77,7 +78,7 @@ def _reciprocal_target_std(
77
78
 
78
79
  See Also
79
80
  --------
80
- :py:meth:`tme.matching_exhaustive.flc_scoring`.
81
+ :py:meth:`tme.matching_scores.flc_scoring`.
81
82
  """
82
83
  ft_shape = template_mask.shape
83
84
  ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
@@ -114,7 +115,8 @@ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
114
115
  @partial(
115
116
  pmap,
116
117
  in_axes=(0,) + (None,) * 6,
117
- static_broadcasted_argnums=[6, 7],
118
+ static_broadcasted_argnums=[6, 7, 8, 9],
119
+ axis_name="batch",
118
120
  )
119
121
  def scan(
120
122
  target: BackendArray,
@@ -125,9 +127,17 @@ def scan(
125
127
  target_filter: BackendArray,
126
128
  fast_shape: Tuple[int],
127
129
  rotate_mask: bool,
130
+ analyzer_class: object,
131
+ analyzer_kwargs: Tuple[Tuple],
128
132
  ) -> Tuple[BackendArray, BackendArray]:
129
133
  eps = jnp.finfo(template.dtype).resolution
130
134
 
135
+ kwargs = lax.switch(
136
+ lax.axis_index("batch"),
137
+ [lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
138
+ )
139
+ analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
140
+
131
141
  if hasattr(target_filter, "shape"):
132
142
  target = _apply_fourier_filter(target, target_filter)
133
143
 
@@ -150,7 +160,7 @@ def scan(
150
160
  _template_filter_func = _apply_fourier_filter
151
161
 
152
162
  def _sample_transform(ret, rotation_matrix):
153
- max_scores, rotations, index = ret
163
+ state, index = ret
154
164
  template_rot, template_mask_rot = be.rigid_transform(
155
165
  arr=template,
156
166
  arr_mask=template_mask,
@@ -163,27 +173,20 @@ def scan(
163
173
  template_rot = _normalize_template(
164
174
  template_rot, template_mask_rot, n_observations
165
175
  )
166
- template_rot = be.topleft_pad(template_rot, fast_shape)
167
- template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)
176
+ rot_pad = be.topleft_pad(template_rot, fast_shape)
177
+ mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
168
178
 
169
179
  scores = scoring_func(
170
- template=template_rot,
171
- template_mask=template_mask_rot,
180
+ template=rot_pad,
181
+ template_mask=mask_rot_pad,
172
182
  ft_target=ft_target,
173
183
  ft_target2=ft_target2,
174
184
  inv_denominator=inv_denominator,
175
185
  n_observations=n_observations,
176
186
  eps=eps,
177
187
  )
178
- max_scores, rotations = be.max_score_over_rotations(
179
- scores, max_scores, rotations, index
180
- )
181
- return (max_scores, rotations, index + 1), None
182
-
183
- score_space = jnp.zeros(fast_shape)
184
- rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
185
- (score_space, rotation_space, _), _ = lax.scan(
186
- _sample_transform, (score_space, rotation_space, 0), rotations
187
- )
188
+ state = analyzer(state, scores, rotation_matrix, rotation_index=index)
189
+ return (state, index + 1), None
188
190
 
189
- return score_space, rotation_space
191
+ (state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
192
+ return state
@@ -0,0 +1,270 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # Henry Gomersall
4
+ # heng@kedevelopments.co.uk
5
+ #
6
+ # All rights reserved.
7
+ #
8
+ # Redistribution and use in source and binary forms, with or without
9
+ # modification, are permitted provided that the following conditions are met:
10
+ #
11
+ # * Redistributions of source code must retain the above copyright notice, this
12
+ # list of conditions and the following disclaimer.
13
+ #
14
+ # * Redistributions in binary form must reproduce the above copyright notice,
15
+ # this list of conditions and the following disclaimer in the documentation
16
+ # and/or other materials provided with the distribution.
17
+ #
18
+ # * Neither the name of the copyright holder nor the names of its contributors
19
+ # may be used to endorse or promote products derived from this software without
20
+ # specific prior written permission.
21
+ #
22
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
25
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
26
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
27
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
28
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
29
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
30
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
31
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32
+ # POSSIBILITY OF SUCH DAMAGE.
33
+ #
34
+
35
+ # This code has been adapted to add support for the out argument in rfftn, irfftn
36
+ # to allow for reusing existing array buffers
37
+ # Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
38
+
39
+ import threading
40
+
41
+ import pyfftw
42
+ import numpy as np
43
+ import pyfftw.builders as builders
44
+ from pyfftw.interfaces import cache
45
+ from pyfftw.builders._utils import _norm_args, _default_effort, _default_threads
46
+
47
+
48
def _Xfftn(
    a,
    s,
    axes,
    overwrite_input,
    planner_effort,
    threads,
    auto_align_input,
    auto_contiguous,
    calling_func,
    normalise_idft=True,
    ortho=False,
    real_direction_flag=None,
    output_array=None,
):
    """Run the transform named ``calling_func`` on ``a`` through a cached FFTW plan.

    The plan is built with ``getattr(pyfftw.builders, calling_func)`` and, when
    the pyfftw interface cache is enabled, stored under a key made of the
    function name, the input's shape/strides/dtype/alignment, ``s``, ``axes``,
    the planner arguments and the calling thread's identifier.

    Parameters
    ----------
    a : array_like
        Input data; converted with :func:`numpy.asanyarray`.
    s, axes : sequence or None
        Forwarded to the builder. Iterables are converted to tuples.
    overwrite_input, planner_effort, threads, auto_align_input, auto_contiguous
        Planner arguments forwarded to the builder. ``overwrite_input`` is not
        forwarded for ``"irfft2"``/``"irfftn"``; instead, the input is copied
        first unless ``overwrite_input`` is true.
    calling_func : str
        Name of the builder function in :mod:`pyfftw.builders`.
    normalise_idft, ortho : bool
        Forwarded to the FFTW object when it is executed.
    real_direction_flag
        Extra planner argument, only forwarded for ``"dct"`` and ``"dst"``.
    output_array : ndarray, optional
        Buffer that receives the result. If omitted, a new array is returned.

    Returns
    -------
    ndarray
        ``output_array`` if one was supplied, otherwise the array holding the
        transform result.

    Raises
    ------
    ValueError
        If ``overwrite_input`` is true while ``a`` is not writeable.
    """
    work_with_copy = False

    a = np.asanyarray(a)

    try:
        s = tuple(s)
    except TypeError:
        pass

    try:
        axes = tuple(axes)
    except TypeError:
        pass

    if calling_func in ("dct", "dst"):
        # real-to-real transforms require passing an additional flag argument
        avoid_copy = False
        args = (
            overwrite_input,
            planner_effort,
            threads,
            auto_align_input,
            auto_contiguous,
            avoid_copy,
            real_direction_flag,
        )
    elif calling_func in ("irfft2", "irfftn"):
        # overwrite_input is not an argument to irfft2 or irfftn
        args = (planner_effort, threads, auto_align_input, auto_contiguous)

        if not overwrite_input:
            # Only irfft2 and irfftn have overwriting the input
            # as the default (and so require the input array to
            # be reloaded).
            work_with_copy = True
    else:
        args = (
            overwrite_input,
            planner_effort,
            threads,
            auto_align_input,
            auto_contiguous,
        )

    if not a.flags.writeable:
        # Special case of a locked array - always work with a
        # copy. See issue #92.
        work_with_copy = True

        if overwrite_input:
            raise ValueError(
                "overwrite_input cannot be True when the "
                + "input array flags.writeable is False"
            )

    if work_with_copy:
        # We make the copy before registering the key so that the
        # copy's stride information will be cached since this will be
        # used for planning. Make sure the copy is byte aligned to
        # prevent further copying
        a_original = a
        a = pyfftw.empty_aligned(shape=a.shape, dtype=a.dtype)
        a[...] = a_original

    # Stays None when the cache is disabled or holds no plan for this key.
    FFTW_object = None
    cache_enabled = cache.is_enabled()

    if cache_enabled:
        alignment = a.ctypes.data % pyfftw.simd_alignment

        # The thread identifier is part of the key, so plans are looked up
        # per calling thread.
        key = (
            calling_func,
            a.shape,
            a.strides,
            a.dtype,
            s.__hash__(),
            axes.__hash__(),
            alignment,
            args,
            threading.get_ident(),
        )

        try:
            if key in cache._fftw_cache:
                FFTW_object = cache._fftw_cache.lookup(key)
        except KeyError:
            # This occurs if the object has fallen out of the cache between
            # the check and the lookup
            FFTW_object = None

    if FFTW_object is None:
        # If we're going to create a new FFTW object and are not
        # working with a copy, then we need to copy the input array to
        # preserve it, otherwise we can't actually take the transform
        # of the input array! (in general, we have to assume that the
        # input array will be destroyed during planning).
        if not work_with_copy:
            a_copy = a.copy()

        planner_args = (a, s, axes) + args

        FFTW_object = getattr(builders, calling_func)(*planner_args)

        # Only copy if the input array is what was actually used
        # (otherwise it shouldn't be overwritten)
        if not work_with_copy and FFTW_object.input_array is a:
            a[:] = a_copy

        if cache_enabled:
            cache._fftw_cache.insert(FFTW_object, key)

        result = FFTW_object(normalise_idft=normalise_idft, ortho=ortho)

        if output_array is None:
            output_array = result
        else:
            # Honour a caller-supplied buffer on a planner/cache miss as well;
            # previously it was silently replaced by the plan's own output
            # array and never written to.
            output_array[...] = result

    else:
        orig_output_array = FFTW_object.output_array
        output_shape = orig_output_array.shape
        output_dtype = orig_output_array.dtype
        output_alignment = FFTW_object.output_alignment

        if output_array is None:
            output_array = pyfftw.empty_aligned(
                output_shape, output_dtype, n=output_alignment
            )

        FFTW_object(
            input_array=a,
            output_array=output_array,
            normalise_idft=normalise_idft,
            ortho=ortho,
        )

    return output_array
197
+
198
+
199
def rfftn(
    a,
    s=None,
    axes=None,
    norm=None,
    overwrite_input=False,
    planner_effort=None,
    threads=None,
    auto_align_input=True,
    auto_contiguous=True,
    out=None,
):
    """Compute the n-dimensional FFT of real input.

    ``a``, ``s``, ``axes`` and ``norm`` follow :func:`numpy.fft.rfftn`;
    the remaining planner arguments are described in the
    :ref:`additional arguments docs<interfaces_additional_args>`.
    ``out`` is handed to ``_Xfftn`` as ``output_array``.
    """
    return _Xfftn(
        a=a,
        s=s,
        axes=axes,
        overwrite_input=overwrite_input,
        planner_effort=_default_effort(planner_effort),
        threads=_default_threads(threads),
        auto_align_input=auto_align_input,
        auto_contiguous=auto_contiguous,
        calling_func="rfftn",
        output_array=out,
        **_norm_args(norm),
    )
234
+
235
+
236
def irfftn(
    a,
    s=None,
    axes=None,
    norm=None,
    overwrite_input=False,
    planner_effort=None,
    threads=None,
    auto_align_input=True,
    auto_contiguous=True,
    out=None,
):
    """Compute the n-dimensional inverse FFT of real-input spectra.

    ``a``, ``s``, ``axes`` and ``norm`` follow :func:`numpy.fft.irfftn`;
    the remaining planner arguments are described in the
    :ref:`additional arguments docs<interfaces_additional_args>`.
    ``out`` is handed to ``_Xfftn`` as ``output_array``.
    """
    effort = _default_effort(planner_effort)
    n_threads = _default_threads(threads)
    norm_kwargs = _norm_args(norm)

    return _Xfftn(
        a,
        s,
        axes,
        overwrite_input,
        effort,
        n_threads,
        auto_align_input,
        auto_contiguous,
        "irfftn",
        output_array=out,
        **norm_kwargs,
    )
@@ -6,12 +6,9 @@ Copyright (c) 2023 European Molecular Biology Laboratory
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- import warnings
9
+ from typing import Tuple, List
10
10
  from importlib.util import find_spec
11
11
  from contextlib import contextmanager
12
- from typing import Tuple, Callable, List
13
-
14
- import numpy as np
15
12
 
16
13
  from .npfftw_backend import NumpyFFTWBackend
17
14
  from ..types import CupyArray, NDArray, shm_type
@@ -114,54 +111,14 @@ class CupyBackend(NumpyFFTWBackend):
114
111
  def unravel_index(self, indices, shape):
115
112
  return self._array_backend.unravel_index(indices=indices, dims=shape)
116
113
 
117
- def build_fft(
118
- self,
119
- fwd_shape: Tuple[int],
120
- inv_shape: Tuple[int],
121
- inv_output_shape: Tuple[int] = None,
122
- fwd_axes: Tuple[int] = None,
123
- inv_axes: Tuple[int] = None,
124
- **kwargs,
125
- ) -> Tuple[Callable, Callable]:
126
- cache = self._array_backend.fft.config.get_plan_cache()
127
- current_device = self._array_backend.cuda.device.get_device_id()
128
-
129
- previous_transform = [fwd_shape, inv_shape]
130
- if current_device in PLAN_CACHE:
131
- previous_transform = PLAN_CACHE[current_device]
132
-
133
- real_diff, cmplx_diff = True, True
134
- if len(fwd_shape) == len(previous_transform[0]):
135
- real_diff = fwd_shape == previous_transform[0]
136
- if len(inv_shape) == len(previous_transform[1]):
137
- cmplx_diff = inv_shape == previous_transform[1]
138
-
139
- if real_diff or cmplx_diff:
140
- cache.clear()
141
-
142
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
143
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
144
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
145
-
146
- def rfftn(
147
- arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
148
- ) -> CupyArray:
149
- return self.rfftn(arr, s=s, axes=fwd_axes)
150
-
151
- def irfftn(
152
- arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
153
- ) -> CupyArray:
154
- return self.irfftn(arr, s=s, axes=inv_axes)
155
-
156
- PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
157
-
158
- return rfftn, irfftn
114
+ def free_cache(self):
115
+ self._array_backend.fft.config.get_plan_cache().clear()
159
116
 
160
117
  def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
161
118
  return self._cufft.rfftn(arr, **kwargs)
162
119
 
163
120
  def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
164
- return self._cufft.irfftn(arr, **kwargs)
121
+ return self._cufft.irfftn(arr, **kwargs).astype(self._float_dtype)
165
122
 
166
123
  def compute_convolution_shapes(
167
124
  self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
@@ -239,13 +196,13 @@ class CupyBackend(NumpyFFTWBackend):
239
196
  )
240
197
  return None
241
198
 
242
- # if data.ndim == 3 and cache and self.texture_available:
243
- # # Device memory pool (should) come to rescue performance
244
- # temp = self.zeros(data.shape, data.dtype)
245
- # texture = self._get_texture(data, order=order, prefilter=prefilter)
246
- # texture.affine(transform_m=matrix, profile=False, output=temp)
247
- # output[out_slice] = temp
248
- # return None
199
+ if data.ndim == 3 and cache and self.texture_available and not batched:
200
+ # Device memory pool (should) come to rescue performance
201
+ temp = self.zeros(data.shape, data.dtype)
202
+ texture = self._get_texture(data, order=order, prefilter=prefilter)
203
+ texture.affine(transform_m=matrix, profile=False, output=temp)
204
+ output[out_slice] = temp
205
+ return None
249
206
 
250
207
  self.affine_transform(
251
208
  input=data,