pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
  6. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +49 -103
  15. scripts/pytme_runner.py +46 -69
  16. tests/preprocessing/test_compose.py +31 -30
  17. tests/preprocessing/test_frequency_filters.py +17 -32
  18. tests/preprocessing/test_preprocessor.py +0 -19
  19. tests/preprocessing/test_utils.py +13 -1
  20. tests/test_analyzer.py +2 -10
  21. tests/test_backends.py +47 -18
  22. tests/test_density.py +72 -13
  23. tests/test_extensions.py +1 -0
  24. tests/test_matching_cli.py +23 -9
  25. tests/test_matching_exhaustive.py +5 -5
  26. tests/test_matching_utils.py +3 -3
  27. tests/test_orientations.py +12 -0
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +91 -68
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +103 -98
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +44 -57
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +17 -3
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post2.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/analyzer/peaks.py CHANGED
@@ -18,7 +18,6 @@ from .base import AbstractAnalyzer
18
18
  from ._utils import score_to_cart
19
19
  from ..backends import backend as be
20
20
  from ..types import BackendArray, NDArray
21
- from ..rotations import euler_to_rotationmatrix
22
21
  from ..matching_utils import split_shape, compute_extraction_box
23
22
 
24
23
  __all__ = [
@@ -182,6 +181,7 @@ class PeakCaller(AbstractAnalyzer):
182
181
  min_score: float = None,
183
182
  max_score: float = None,
184
183
  batch_dims: Tuple[int] = None,
184
+ projection_dims: Tuple[int] = None,
185
185
  shm_handler: object = None,
186
186
  **kwargs,
187
187
  ):
@@ -197,9 +197,13 @@ class PeakCaller(AbstractAnalyzer):
197
197
  self.min_distance = int(min_distance)
198
198
  self.min_boundary_distance = int(min_boundary_distance)
199
199
 
200
- self.batch_dims = batch_dims
200
+ self.batch_dims = ()
201
201
  if batch_dims is not None:
202
- self.batch_dims = tuple(int(x) for x in self.batch_dims)
202
+ self.batch_dims = tuple(int(x) for x in batch_dims)
203
+
204
+ self.projection_dims = ()
205
+ if projection_dims is not None:
206
+ self.projection_dims = tuple(int(x) for x in projection_dims)
203
207
 
204
208
  self.min_score, self.max_score = min_score, max_score
205
209
 
@@ -231,7 +235,7 @@ class PeakCaller(AbstractAnalyzer):
231
235
 
232
236
  rdim = len(self.shape)
233
237
  if self.batch_dims:
234
- rdim -= len(self.batch_dims)
238
+ rdim = rdim - len(self.batch_dims) + len(self.projection_dims)
235
239
 
236
240
  rotations = be.full(
237
241
  (self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
@@ -388,6 +392,20 @@ class PeakCaller(AbstractAnalyzer):
388
392
 
389
393
  return state
390
394
 
395
+ def correct_background(self, state, mean, inv_std=1, **kwargs):
396
+ arr_type = type(be.zeros((1,), be._float_dtype))
397
+ translations, rotations, scores, details = state
398
+
399
+ if isinstance(mean, arr_type):
400
+ mean = mean[tuple(be.astype(translations.T, int))]
401
+ scores = be.subtract(scores, mean, out=scores)
402
+
403
+ if isinstance(inv_std, arr_type):
404
+ inv_std = inv_std[tuple(be.astype(translations.T, int))]
405
+ scores = be.multiply(scores, inv_std, out=scores)
406
+
407
+ return translations, rotations, scores, details
408
+
391
409
  @classmethod
392
410
  def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
393
411
  """
@@ -778,6 +796,9 @@ class PeakCallerRecursiveMasking(PeakCaller):
778
796
  mask = be.to_backend_array(mask)
779
797
  mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
780
798
 
799
+ if min_score is None:
800
+ min_score = self.min_score
801
+
781
802
  if min_score is None:
782
803
  min_score = be.min(scores) - 1
783
804
 
@@ -849,15 +870,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
849
870
  """
850
871
  if rotation_space is None or rotation_mapping is None:
851
872
  return rotation_matrix
852
-
853
- rotation = rotation_mapping[rotation_space[tuple(peak)]]
854
-
855
- # Old versions of rotation mapping contained Euler angles
856
- if rotation.ndim != 2:
857
- rotation = be.to_backend_array(
858
- euler_to_rotationmatrix(be.to_numpy_array(rotation))
859
- )
860
- return rotation
873
+ return rotation_mapping[rotation_space[tuple(peak)]]
861
874
 
862
875
 
863
876
  class PeakCallerScipy(PeakCaller):
tme/analyzer/proxy.py CHANGED
@@ -85,6 +85,16 @@ class StatelessSharedAnalyzerProxy:
85
85
  final_state = tuple(self._shared_to_object(x) for x in final_state)
86
86
  return self._analyzer.result(final_state, **kwargs)
87
87
 
88
+ def correct_background(self, state, *args, **kwargs):
89
+ if self._shared:
90
+ # Copy to not correct the internal score array across processes
91
+ backend_arr = type(be.zeros((1), dtype=be._float_dtype))
92
+ state = tuple(self._shared_to_object(x) for x in state)
93
+ state = tuple(
94
+ be.copy(x) if isinstance(x, backend_arr) else x for x in state
95
+ )
96
+ return self._analyzer.correct_background(state, *args, **kwargs)
97
+
88
98
  def merge(self, *args, **kwargs):
89
99
  return self._analyzer.merge(*args, **kwargs)
90
100
 
@@ -121,3 +131,7 @@ class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
121
131
  def result(self, **kwargs):
122
132
  """Extract final result"""
123
133
  return super().result(self._state, **kwargs)
134
+
135
+ def correct_background(self, *args, **kwargs):
136
+ # We always assign to state as this operation can not be shared
137
+ self._state = super().correct_background(self._state, *args, **kwargs)
@@ -10,11 +10,11 @@ from typing import Tuple
10
10
  from functools import partial
11
11
 
12
12
  import jax.numpy as jnp
13
- from jax import pmap, lax, vmap
13
+ from jax import pmap, lax, jit
14
14
 
15
15
  from ..types import BackendArray
16
16
  from ..backends import backend as be
17
- from ..matching_utils import normalize_template as _normalize_template
17
+ from ..matching_utils import standardize, to_padded
18
18
 
19
19
 
20
20
  __all__ = ["scan", "setup_scan"]
@@ -62,15 +62,14 @@ def _flcSphere_scoring(
62
62
  Computes :py:meth:`tme.matching_scores.corr_scoring`.
63
63
  """
64
64
  correlation = _correlate(template=template, ft_target=ft_target)
65
- correlation = correlation.at[:].multiply(inv_denominator)
66
- return correlation
65
+ return correlation.at[:].multiply(inv_denominator)
67
66
 
68
67
 
69
68
  def _reciprocal_target_std(
70
69
  ft_target: BackendArray,
71
70
  ft_target2: BackendArray,
72
71
  template_mask: BackendArray,
73
- n_observations: float,
72
+ n_obs: float,
74
73
  eps: float,
75
74
  ) -> BackendArray:
76
75
  """
@@ -80,16 +79,16 @@ def _reciprocal_target_std(
80
79
  --------
81
80
  :py:meth:`tme.matching_scores.flc_scoring`.
82
81
  """
83
- ft_shape = template_mask.shape
84
- ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
82
+ shape = template_mask.shape
83
+ ft_template_mask = jnp.fft.rfftn(template_mask, s=shape)
85
84
 
86
85
  # E(X^2)- E(X)^2
87
- exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=ft_shape)
88
- exp_sq = exp_sq.at[:].divide(n_observations)
86
+ exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=shape)
87
+ exp_sq = exp_sq.at[:].divide(n_obs)
89
88
 
90
89
  ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
91
- sq_exp = jnp.fft.irfftn(ft_template_mask, s=ft_shape)
92
- sq_exp = sq_exp.at[:].divide(n_observations)
90
+ sq_exp = jnp.fft.irfftn(ft_template_mask, s=shape)
91
+ sq_exp = sq_exp.at[:].divide(n_obs)
93
92
  sq_exp = sq_exp.at[:].power(2)
94
93
 
95
94
  exp_sq = exp_sq.at[:].add(-sq_exp)
@@ -97,7 +96,7 @@ def _reciprocal_target_std(
97
96
  exp_sq = exp_sq.at[:].power(0.5)
98
97
 
99
98
  exp_sq = exp_sq.at[:].set(
100
- jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_observations))
99
+ jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_obs))
101
100
  )
102
101
  return exp_sq
103
102
 
@@ -108,32 +107,21 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
108
107
  return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
109
108
 
110
109
 
111
- def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
112
- return arr
113
-
114
-
115
- def _mask_scores(arr, mask):
116
- return arr.at[:].multiply(mask)
117
-
118
-
119
- def _select_config(analyzer_kwargs, device_idx):
120
- return analyzer_kwargs[device_idx]
121
-
122
-
123
- def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
110
+ def setup_scan(analyzer_kwargs, analyzer, fast_shape, rotate_mask, match_projection):
124
111
  """Create separate scan function with initialized analyzer for each device"""
125
112
  device_scans = [
126
113
  partial(
127
114
  scan,
128
115
  fast_shape=fast_shape,
129
116
  rotate_mask=rotate_mask,
130
- analyzer=callback_class(**device_config),
131
- ) for device_config in analyzer_kwargs
117
+ analyzer=analyzer(**device_config),
118
+ )
119
+ for device_config in analyzer_kwargs
132
120
  ]
133
121
 
134
122
  @partial(
135
123
  pmap,
136
- in_axes=(0,) + (None,) * 6,
124
+ in_axes=(0,) + (None,) * 7,
137
125
  axis_name="batch",
138
126
  )
139
127
  def scan_combined(
@@ -144,6 +132,7 @@ def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
144
132
  template_filter,
145
133
  target_filter,
146
134
  score_mask,
135
+ background_template,
147
136
  ):
148
137
  return lax.switch(
149
138
  lax.axis_index("batch"),
@@ -155,10 +144,13 @@ def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
155
144
  template_filter,
156
145
  target_filter,
157
146
  score_mask,
147
+ background_template,
158
148
  )
149
+
159
150
  return scan_combined
160
151
 
161
152
 
153
+ @partial(jit, static_argnums=(8, 9, 10))
162
154
  def scan(
163
155
  target: BackendArray,
164
156
  template: BackendArray,
@@ -167,67 +159,98 @@ def scan(
167
159
  template_filter: BackendArray,
168
160
  target_filter: BackendArray,
169
161
  score_mask: BackendArray,
162
+ background_template: BackendArray,
170
163
  fast_shape: Tuple[int],
171
164
  rotate_mask: bool,
172
165
  analyzer: object,
173
- ) -> Tuple[BackendArray, BackendArray]:
166
+ ) -> Tuple:
174
167
  eps = jnp.finfo(template.dtype).resolution
175
168
 
176
- if hasattr(target_filter, "shape"):
169
+ if target_filter.shape != ():
177
170
  target = _apply_fourier_filter(target, target_filter)
178
171
 
179
172
  ft_target = jnp.fft.rfftn(target, s=fast_shape)
180
173
  ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
181
- inv_denominator, target, scoring_func = None, None, _flc_scoring
174
+ _n_obs, _inv_denominator, target = None, None, None
175
+
176
+ unpadded_slice = tuple(slice(0, x) for x in template.shape)
177
+ rot_buffer, mask_rot_buffer = jnp.zeros(fast_shape), jnp.zeros(fast_shape)
182
178
  if not rotate_mask:
183
- n_observations = jnp.sum(template_mask)
184
- inv_denominator = _reciprocal_target_std(
179
+ _n_obs = jnp.sum(template_mask)
180
+ _inv_denominator = _reciprocal_target_std(
185
181
  ft_target=ft_target,
186
182
  ft_target2=ft_target2,
187
- template_mask=be.topleft_pad(template_mask, fast_shape),
183
+ template_mask=to_padded(mask_rot_buffer, template_mask, unpadded_slice),
188
184
  eps=eps,
189
- n_observations=n_observations,
185
+ n_obs=_n_obs,
190
186
  )
191
- ft_target2, scoring_func = None, _flcSphere_scoring
187
+ ft_target2 = None
192
188
 
193
- _template_filter_func = _identity
194
- if template_filter.shape != ():
195
- _template_filter_func = _apply_fourier_filter
189
+ mask_scores = score_mask.shape != ()
190
+ filter_template = template_filter.shape != ()
191
+ bg_correction = background_template.shape != ()
192
+ bg_scores = jnp.zeros(fast_shape) if bg_correction else 0
196
193
 
197
- _score_mask_func = _identity
198
- if score_mask.shape != ():
199
- _score_mask_func = _mask_scores
194
+ _template_mask_rot = template_mask
195
+ template_indices = be._index_grid(template.shape)
196
+ center = be.divide(be.to_backend_array(template.shape) - 1, 2)
200
197
 
201
198
  def _sample_transform(ret, rotation_matrix):
202
- state, index = ret
203
- template_rot, template_mask_rot = be.rigid_transform(
204
- arr=template,
205
- arr_mask=template_mask,
206
- rotation_matrix=rotation_matrix,
207
- order=1, # thats all we get for now
199
+ matrix = be._build_transform_matrix(
200
+ rotation_matrix=rotation_matrix, center=center
208
201
  )
202
+ indices = be._transform_indices(template_indices, matrix)
203
+
204
+ template_rot = be._interpolate(template, indices, order=1)
205
+ n_obs, template_mask_rot = _n_obs, _template_mask_rot
206
+ if rotate_mask:
207
+ template_mask_rot = be._interpolate(template_mask, indices, order=1)
208
+ n_obs = jnp.sum(template_mask_rot)
209
+
210
+ if filter_template:
211
+ template_rot = _apply_fourier_filter(template_rot, template_filter)
212
+ template_rot = standardize(template_rot, template_mask_rot, n_obs)
213
+
214
+ rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
215
+
216
+ inv_denominator = _inv_denominator
217
+ if rotate_mask:
218
+ mask_rot_pad = to_padded(mask_rot_buffer, template_mask_rot, unpadded_slice)
219
+ inv_denominator = _reciprocal_target_std(
220
+ ft_target=ft_target,
221
+ ft_target2=ft_target2,
222
+ template_mask=mask_rot_pad,
223
+ n_obs=n_obs,
224
+ eps=eps,
225
+ )
226
+
227
+ scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
228
+ if mask_scores:
229
+ scores = scores.at[:].multiply(score_mask)
230
+
231
+ state, bg_scores, index = ret
232
+ state = analyzer(state, scores, rotation_matrix, rotation_index=index)
209
233
 
210
- n_observations = jnp.sum(template_mask_rot)
211
- template_rot = _template_filter_func(template_rot, template_filter)
212
- template_rot = _normalize_template(
213
- template_rot, template_mask_rot, n_observations
214
- )
215
- rot_pad = be.topleft_pad(template_rot, fast_shape)
216
- mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
234
+ if bg_correction:
235
+ template_rot = be._interpolate(background_template, indices, order=1)
236
+ if filter_template:
237
+ template_rot = _apply_fourier_filter(template_rot, template_filter)
238
+ template_rot = standardize(template_rot, template_mask_rot, n_obs)
217
239
 
218
- scores = scoring_func(
219
- template=rot_pad,
220
- template_mask=mask_rot_pad,
221
- ft_target=ft_target,
222
- ft_target2=ft_target2,
223
- inv_denominator=inv_denominator,
224
- n_observations=n_observations,
225
- eps=eps,
226
- )
227
- scores = _score_mask_func(scores, score_mask)
240
+ rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
241
+ scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
242
+ bg_scores = jnp.maximum(bg_scores, scores)
228
243
 
229
- state = analyzer(state, scores, rotation_matrix, rotation_index=index)
230
- return (state, index + 1), None
244
+ return (state, bg_scores, index + 1), None
245
+
246
+ (state, bg_scores, _), _ = lax.scan(
247
+ _sample_transform, (analyzer.init_state(), bg_scores, 0), rotations
248
+ )
249
+
250
+ if bg_correction:
251
+ if mask_scores:
252
+ bg_scores = bg_scores.at[:].multiply(score_mask)
253
+ bg_scores = bg_scores.at[:].add(-be.mean(bg_scores))
254
+ state = analyzer.correct_background(state, bg_scores)
231
255
 
232
- (state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
233
256
  return state
@@ -33,7 +33,6 @@ class CupyBackend(NumpyFFTWBackend):
33
33
  import cupy as cp
34
34
  import cupyx.scipy.fft as cufft
35
35
  from cupyx.scipy.ndimage import affine_transform, maximum_filter
36
- from ._cupy_utils import affine_transform_batch
37
36
 
38
37
  float_dtype = cp.float32 if float_dtype is None else float_dtype
39
38
  complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
@@ -51,7 +50,6 @@ class CupyBackend(NumpyFFTWBackend):
51
50
  self._cufft = cufft
52
51
  self.maximum_filter = maximum_filter
53
52
  self.affine_transform = affine_transform
54
- self.affine_transform_batch = affine_transform_batch
55
53
 
56
54
  itype = f"int{self.datatype_bytes(int_dtype) * 8}"
57
55
  ftype = f"float{self.datatype_bytes(float_dtype) * 8}"
@@ -157,8 +155,8 @@ class CupyBackend(NumpyFFTWBackend):
157
155
 
158
156
  from voltools import StaticVolume
159
157
 
160
- # Only keep template and potential corresponding mask in cache
161
- if len(TEXTURE_CACHE) >= 2:
158
+ # Only keep template, mask and noise template in cache
159
+ if len(TEXTURE_CACHE) >= 3:
162
160
  TEXTURE_CACHE.clear()
163
161
 
164
162
  interpolation = "filt_bspline"
@@ -174,7 +172,7 @@ class CupyBackend(NumpyFFTWBackend):
174
172
 
175
173
  return TEXTURE_CACHE[key]
176
174
 
177
- def _rigid_transform(
175
+ def _transform(
178
176
  self,
179
177
  data: CupyArray,
180
178
  matrix: CupyArray,
@@ -182,21 +180,10 @@ class CupyBackend(NumpyFFTWBackend):
182
180
  prefilter: bool,
183
181
  order: int,
184
182
  cache: bool = False,
185
- batched: bool = False,
186
- ) -> None:
183
+ ) -> CupyArray:
187
184
  out_slice = tuple(slice(0, stop) for stop in data.shape)
188
- if batched:
189
- self.affine_transform_batch(
190
- input=data,
191
- matrix=matrix,
192
- mode="constant",
193
- output=output[out_slice],
194
- order=order,
195
- prefilter=prefilter,
196
- )
197
- return None
198
185
 
199
- if data.ndim == 3 and cache and self.texture_available and not batched:
186
+ if data.ndim == 3 and cache and self.texture_available:
200
187
  # Device memory pool (should) come to rescue performance
201
188
  temp = self.zeros(data.shape, data.dtype)
202
189
  texture = self._get_texture(data, order=order, prefilter=prefilter)
@@ -204,7 +191,7 @@ class CupyBackend(NumpyFFTWBackend):
204
191
  output[out_slice] = temp
205
192
  return None
206
193
 
207
- self.affine_transform(
194
+ return self.affine_transform(
208
195
  input=data,
209
196
  matrix=matrix,
210
197
  mode="constant",