pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tme/analyzer/peaks.py
CHANGED
@@ -18,7 +18,6 @@ from .base import AbstractAnalyzer
|
|
18
18
|
from ._utils import score_to_cart
|
19
19
|
from ..backends import backend as be
|
20
20
|
from ..types import BackendArray, NDArray
|
21
|
-
from ..rotations import euler_to_rotationmatrix
|
22
21
|
from ..matching_utils import split_shape, compute_extraction_box
|
23
22
|
|
24
23
|
__all__ = [
|
@@ -182,6 +181,7 @@ class PeakCaller(AbstractAnalyzer):
|
|
182
181
|
min_score: float = None,
|
183
182
|
max_score: float = None,
|
184
183
|
batch_dims: Tuple[int] = None,
|
184
|
+
projection_dims: Tuple[int] = None,
|
185
185
|
shm_handler: object = None,
|
186
186
|
**kwargs,
|
187
187
|
):
|
@@ -197,9 +197,13 @@ class PeakCaller(AbstractAnalyzer):
|
|
197
197
|
self.min_distance = int(min_distance)
|
198
198
|
self.min_boundary_distance = int(min_boundary_distance)
|
199
199
|
|
200
|
-
self.batch_dims =
|
200
|
+
self.batch_dims = ()
|
201
201
|
if batch_dims is not None:
|
202
|
-
self.batch_dims = tuple(int(x) for x in
|
202
|
+
self.batch_dims = tuple(int(x) for x in batch_dims)
|
203
|
+
|
204
|
+
self.projection_dims = ()
|
205
|
+
if projection_dims is not None:
|
206
|
+
self.projection_dims = tuple(int(x) for x in projection_dims)
|
203
207
|
|
204
208
|
self.min_score, self.max_score = min_score, max_score
|
205
209
|
|
@@ -231,7 +235,7 @@ class PeakCaller(AbstractAnalyzer):
|
|
231
235
|
|
232
236
|
rdim = len(self.shape)
|
233
237
|
if self.batch_dims:
|
234
|
-
rdim
|
238
|
+
rdim = rdim - len(self.batch_dims) + len(self.projection_dims)
|
235
239
|
|
236
240
|
rotations = be.full(
|
237
241
|
(self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
|
@@ -388,6 +392,20 @@ class PeakCaller(AbstractAnalyzer):
|
|
388
392
|
|
389
393
|
return state
|
390
394
|
|
395
|
+
def correct_background(self, state, mean, inv_std=1, **kwargs):
|
396
|
+
arr_type = type(be.zeros((1,), be._float_dtype))
|
397
|
+
translations, rotations, scores, details = state
|
398
|
+
|
399
|
+
if isinstance(mean, arr_type):
|
400
|
+
mean = mean[tuple(be.astype(translations.T, int))]
|
401
|
+
scores = be.subtract(scores, mean, out=scores)
|
402
|
+
|
403
|
+
if isinstance(inv_std, arr_type):
|
404
|
+
inv_std = inv_std[tuple(be.astype(translations.T, int))]
|
405
|
+
scores = be.multiply(scores, inv_std, out=scores)
|
406
|
+
|
407
|
+
return translations, rotations, scores, details
|
408
|
+
|
391
409
|
@classmethod
|
392
410
|
def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
|
393
411
|
"""
|
@@ -778,6 +796,9 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
778
796
|
mask = be.to_backend_array(mask)
|
779
797
|
mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
|
780
798
|
|
799
|
+
if min_score is None:
|
800
|
+
min_score = self.min_score
|
801
|
+
|
781
802
|
if min_score is None:
|
782
803
|
min_score = be.min(scores) - 1
|
783
804
|
|
@@ -849,15 +870,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
849
870
|
"""
|
850
871
|
if rotation_space is None or rotation_mapping is None:
|
851
872
|
return rotation_matrix
|
852
|
-
|
853
|
-
rotation = rotation_mapping[rotation_space[tuple(peak)]]
|
854
|
-
|
855
|
-
# Old versions of rotation mapping contained Euler angles
|
856
|
-
if rotation.ndim != 2:
|
857
|
-
rotation = be.to_backend_array(
|
858
|
-
euler_to_rotationmatrix(be.to_numpy_array(rotation))
|
859
|
-
)
|
860
|
-
return rotation
|
873
|
+
return rotation_mapping[rotation_space[tuple(peak)]]
|
861
874
|
|
862
875
|
|
863
876
|
class PeakCallerScipy(PeakCaller):
|
tme/analyzer/proxy.py
CHANGED
@@ -85,6 +85,16 @@ class StatelessSharedAnalyzerProxy:
|
|
85
85
|
final_state = tuple(self._shared_to_object(x) for x in final_state)
|
86
86
|
return self._analyzer.result(final_state, **kwargs)
|
87
87
|
|
88
|
+
def correct_background(self, state, *args, **kwargs):
|
89
|
+
if self._shared:
|
90
|
+
# Copy to not correct the internal score array across processes
|
91
|
+
backend_arr = type(be.zeros((1), dtype=be._float_dtype))
|
92
|
+
state = tuple(self._shared_to_object(x) for x in state)
|
93
|
+
state = tuple(
|
94
|
+
be.copy(x) if isinstance(x, backend_arr) else x for x in state
|
95
|
+
)
|
96
|
+
return self._analyzer.correct_background(state, *args, **kwargs)
|
97
|
+
|
88
98
|
def merge(self, *args, **kwargs):
|
89
99
|
return self._analyzer.merge(*args, **kwargs)
|
90
100
|
|
@@ -121,3 +131,7 @@ class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
|
|
121
131
|
def result(self, **kwargs):
|
122
132
|
"""Extract final result"""
|
123
133
|
return super().result(self._state, **kwargs)
|
134
|
+
|
135
|
+
def correct_background(self, *args, **kwargs):
|
136
|
+
# We always assign to state as this operation can not be shared
|
137
|
+
self._state = super().correct_background(self._state, *args, **kwargs)
|
tme/backends/_jax_utils.py
CHANGED
@@ -10,11 +10,11 @@ from typing import Tuple
|
|
10
10
|
from functools import partial
|
11
11
|
|
12
12
|
import jax.numpy as jnp
|
13
|
-
from jax import pmap, lax,
|
13
|
+
from jax import pmap, lax, jit
|
14
14
|
|
15
15
|
from ..types import BackendArray
|
16
16
|
from ..backends import backend as be
|
17
|
-
from ..matching_utils import
|
17
|
+
from ..matching_utils import standardize, to_padded
|
18
18
|
|
19
19
|
|
20
20
|
__all__ = ["scan", "setup_scan"]
|
@@ -62,15 +62,14 @@ def _flcSphere_scoring(
|
|
62
62
|
Computes :py:meth:`tme.matching_scores.corr_scoring`.
|
63
63
|
"""
|
64
64
|
correlation = _correlate(template=template, ft_target=ft_target)
|
65
|
-
|
66
|
-
return correlation
|
65
|
+
return correlation.at[:].multiply(inv_denominator)
|
67
66
|
|
68
67
|
|
69
68
|
def _reciprocal_target_std(
|
70
69
|
ft_target: BackendArray,
|
71
70
|
ft_target2: BackendArray,
|
72
71
|
template_mask: BackendArray,
|
73
|
-
|
72
|
+
n_obs: float,
|
74
73
|
eps: float,
|
75
74
|
) -> BackendArray:
|
76
75
|
"""
|
@@ -80,16 +79,16 @@ def _reciprocal_target_std(
|
|
80
79
|
--------
|
81
80
|
:py:meth:`tme.matching_scores.flc_scoring`.
|
82
81
|
"""
|
83
|
-
|
84
|
-
ft_template_mask = jnp.fft.rfftn(template_mask, s=
|
82
|
+
shape = template_mask.shape
|
83
|
+
ft_template_mask = jnp.fft.rfftn(template_mask, s=shape)
|
85
84
|
|
86
85
|
# E(X^2)- E(X)^2
|
87
|
-
exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=
|
88
|
-
exp_sq = exp_sq.at[:].divide(
|
86
|
+
exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=shape)
|
87
|
+
exp_sq = exp_sq.at[:].divide(n_obs)
|
89
88
|
|
90
89
|
ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
|
91
|
-
sq_exp = jnp.fft.irfftn(ft_template_mask, s=
|
92
|
-
sq_exp = sq_exp.at[:].divide(
|
90
|
+
sq_exp = jnp.fft.irfftn(ft_template_mask, s=shape)
|
91
|
+
sq_exp = sq_exp.at[:].divide(n_obs)
|
93
92
|
sq_exp = sq_exp.at[:].power(2)
|
94
93
|
|
95
94
|
exp_sq = exp_sq.at[:].add(-sq_exp)
|
@@ -97,7 +96,7 @@ def _reciprocal_target_std(
|
|
97
96
|
exp_sq = exp_sq.at[:].power(0.5)
|
98
97
|
|
99
98
|
exp_sq = exp_sq.at[:].set(
|
100
|
-
jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq *
|
99
|
+
jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_obs))
|
101
100
|
)
|
102
101
|
return exp_sq
|
103
102
|
|
@@ -108,32 +107,21 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
|
|
108
107
|
return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
|
109
108
|
|
110
109
|
|
111
|
-
def
|
112
|
-
return arr
|
113
|
-
|
114
|
-
|
115
|
-
def _mask_scores(arr, mask):
|
116
|
-
return arr.at[:].multiply(mask)
|
117
|
-
|
118
|
-
|
119
|
-
def _select_config(analyzer_kwargs, device_idx):
|
120
|
-
return analyzer_kwargs[device_idx]
|
121
|
-
|
122
|
-
|
123
|
-
def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
|
110
|
+
def setup_scan(analyzer_kwargs, analyzer, fast_shape, rotate_mask, match_projection):
|
124
111
|
"""Create separate scan function with initialized analyzer for each device"""
|
125
112
|
device_scans = [
|
126
113
|
partial(
|
127
114
|
scan,
|
128
115
|
fast_shape=fast_shape,
|
129
116
|
rotate_mask=rotate_mask,
|
130
|
-
analyzer=
|
131
|
-
)
|
117
|
+
analyzer=analyzer(**device_config),
|
118
|
+
)
|
119
|
+
for device_config in analyzer_kwargs
|
132
120
|
]
|
133
121
|
|
134
122
|
@partial(
|
135
123
|
pmap,
|
136
|
-
in_axes=(0,) + (None,) *
|
124
|
+
in_axes=(0,) + (None,) * 7,
|
137
125
|
axis_name="batch",
|
138
126
|
)
|
139
127
|
def scan_combined(
|
@@ -144,6 +132,7 @@ def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
|
|
144
132
|
template_filter,
|
145
133
|
target_filter,
|
146
134
|
score_mask,
|
135
|
+
background_template,
|
147
136
|
):
|
148
137
|
return lax.switch(
|
149
138
|
lax.axis_index("batch"),
|
@@ -155,10 +144,13 @@ def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
|
|
155
144
|
template_filter,
|
156
145
|
target_filter,
|
157
146
|
score_mask,
|
147
|
+
background_template,
|
158
148
|
)
|
149
|
+
|
159
150
|
return scan_combined
|
160
151
|
|
161
152
|
|
153
|
+
@partial(jit, static_argnums=(8, 9, 10))
|
162
154
|
def scan(
|
163
155
|
target: BackendArray,
|
164
156
|
template: BackendArray,
|
@@ -167,67 +159,98 @@ def scan(
|
|
167
159
|
template_filter: BackendArray,
|
168
160
|
target_filter: BackendArray,
|
169
161
|
score_mask: BackendArray,
|
162
|
+
background_template: BackendArray,
|
170
163
|
fast_shape: Tuple[int],
|
171
164
|
rotate_mask: bool,
|
172
165
|
analyzer: object,
|
173
|
-
) -> Tuple
|
166
|
+
) -> Tuple:
|
174
167
|
eps = jnp.finfo(template.dtype).resolution
|
175
168
|
|
176
|
-
if
|
169
|
+
if target_filter.shape != ():
|
177
170
|
target = _apply_fourier_filter(target, target_filter)
|
178
171
|
|
179
172
|
ft_target = jnp.fft.rfftn(target, s=fast_shape)
|
180
173
|
ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
|
181
|
-
|
174
|
+
_n_obs, _inv_denominator, target = None, None, None
|
175
|
+
|
176
|
+
unpadded_slice = tuple(slice(0, x) for x in template.shape)
|
177
|
+
rot_buffer, mask_rot_buffer = jnp.zeros(fast_shape), jnp.zeros(fast_shape)
|
182
178
|
if not rotate_mask:
|
183
|
-
|
184
|
-
|
179
|
+
_n_obs = jnp.sum(template_mask)
|
180
|
+
_inv_denominator = _reciprocal_target_std(
|
185
181
|
ft_target=ft_target,
|
186
182
|
ft_target2=ft_target2,
|
187
|
-
template_mask=
|
183
|
+
template_mask=to_padded(mask_rot_buffer, template_mask, unpadded_slice),
|
188
184
|
eps=eps,
|
189
|
-
|
185
|
+
n_obs=_n_obs,
|
190
186
|
)
|
191
|
-
ft_target2
|
187
|
+
ft_target2 = None
|
192
188
|
|
193
|
-
|
194
|
-
|
195
|
-
|
189
|
+
mask_scores = score_mask.shape != ()
|
190
|
+
filter_template = template_filter.shape != ()
|
191
|
+
bg_correction = background_template.shape != ()
|
192
|
+
bg_scores = jnp.zeros(fast_shape) if bg_correction else 0
|
196
193
|
|
197
|
-
|
198
|
-
|
199
|
-
|
194
|
+
_template_mask_rot = template_mask
|
195
|
+
template_indices = be._index_grid(template.shape)
|
196
|
+
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
200
197
|
|
201
198
|
def _sample_transform(ret, rotation_matrix):
|
202
|
-
|
203
|
-
|
204
|
-
arr=template,
|
205
|
-
arr_mask=template_mask,
|
206
|
-
rotation_matrix=rotation_matrix,
|
207
|
-
order=1, # thats all we get for now
|
199
|
+
matrix = be._build_transform_matrix(
|
200
|
+
rotation_matrix=rotation_matrix, center=center
|
208
201
|
)
|
202
|
+
indices = be._transform_indices(template_indices, matrix)
|
203
|
+
|
204
|
+
template_rot = be._interpolate(template, indices, order=1)
|
205
|
+
n_obs, template_mask_rot = _n_obs, _template_mask_rot
|
206
|
+
if rotate_mask:
|
207
|
+
template_mask_rot = be._interpolate(template_mask, indices, order=1)
|
208
|
+
n_obs = jnp.sum(template_mask_rot)
|
209
|
+
|
210
|
+
if filter_template:
|
211
|
+
template_rot = _apply_fourier_filter(template_rot, template_filter)
|
212
|
+
template_rot = standardize(template_rot, template_mask_rot, n_obs)
|
213
|
+
|
214
|
+
rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
|
215
|
+
|
216
|
+
inv_denominator = _inv_denominator
|
217
|
+
if rotate_mask:
|
218
|
+
mask_rot_pad = to_padded(mask_rot_buffer, template_mask_rot, unpadded_slice)
|
219
|
+
inv_denominator = _reciprocal_target_std(
|
220
|
+
ft_target=ft_target,
|
221
|
+
ft_target2=ft_target2,
|
222
|
+
template_mask=mask_rot_pad,
|
223
|
+
n_obs=n_obs,
|
224
|
+
eps=eps,
|
225
|
+
)
|
226
|
+
|
227
|
+
scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
|
228
|
+
if mask_scores:
|
229
|
+
scores = scores.at[:].multiply(score_mask)
|
230
|
+
|
231
|
+
state, bg_scores, index = ret
|
232
|
+
state = analyzer(state, scores, rotation_matrix, rotation_index=index)
|
209
233
|
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
rot_pad = be.topleft_pad(template_rot, fast_shape)
|
216
|
-
mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
|
234
|
+
if bg_correction:
|
235
|
+
template_rot = be._interpolate(background_template, indices, order=1)
|
236
|
+
if filter_template:
|
237
|
+
template_rot = _apply_fourier_filter(template_rot, template_filter)
|
238
|
+
template_rot = standardize(template_rot, template_mask_rot, n_obs)
|
217
239
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
ft_target=ft_target,
|
222
|
-
ft_target2=ft_target2,
|
223
|
-
inv_denominator=inv_denominator,
|
224
|
-
n_observations=n_observations,
|
225
|
-
eps=eps,
|
226
|
-
)
|
227
|
-
scores = _score_mask_func(scores, score_mask)
|
240
|
+
rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
|
241
|
+
scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
|
242
|
+
bg_scores = jnp.maximum(bg_scores, scores)
|
228
243
|
|
229
|
-
|
230
|
-
|
244
|
+
return (state, bg_scores, index + 1), None
|
245
|
+
|
246
|
+
(state, bg_scores, _), _ = lax.scan(
|
247
|
+
_sample_transform, (analyzer.init_state(), bg_scores, 0), rotations
|
248
|
+
)
|
249
|
+
|
250
|
+
if bg_correction:
|
251
|
+
if mask_scores:
|
252
|
+
bg_scores = bg_scores.at[:].multiply(score_mask)
|
253
|
+
bg_scores = bg_scores.at[:].add(-be.mean(bg_scores))
|
254
|
+
state = analyzer.correct_background(state, bg_scores)
|
231
255
|
|
232
|
-
(state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
|
233
256
|
return state
|
tme/backends/cupy_backend.py
CHANGED
@@ -33,7 +33,6 @@ class CupyBackend(NumpyFFTWBackend):
|
|
33
33
|
import cupy as cp
|
34
34
|
import cupyx.scipy.fft as cufft
|
35
35
|
from cupyx.scipy.ndimage import affine_transform, maximum_filter
|
36
|
-
from ._cupy_utils import affine_transform_batch
|
37
36
|
|
38
37
|
float_dtype = cp.float32 if float_dtype is None else float_dtype
|
39
38
|
complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
|
@@ -51,7 +50,6 @@ class CupyBackend(NumpyFFTWBackend):
|
|
51
50
|
self._cufft = cufft
|
52
51
|
self.maximum_filter = maximum_filter
|
53
52
|
self.affine_transform = affine_transform
|
54
|
-
self.affine_transform_batch = affine_transform_batch
|
55
53
|
|
56
54
|
itype = f"int{self.datatype_bytes(int_dtype) * 8}"
|
57
55
|
ftype = f"float{self.datatype_bytes(float_dtype) * 8}"
|
@@ -157,8 +155,8 @@ class CupyBackend(NumpyFFTWBackend):
|
|
157
155
|
|
158
156
|
from voltools import StaticVolume
|
159
157
|
|
160
|
-
# Only keep template and
|
161
|
-
if len(TEXTURE_CACHE) >=
|
158
|
+
# Only keep template, mask and noise template in cache
|
159
|
+
if len(TEXTURE_CACHE) >= 3:
|
162
160
|
TEXTURE_CACHE.clear()
|
163
161
|
|
164
162
|
interpolation = "filt_bspline"
|
@@ -174,7 +172,7 @@ class CupyBackend(NumpyFFTWBackend):
|
|
174
172
|
|
175
173
|
return TEXTURE_CACHE[key]
|
176
174
|
|
177
|
-
def
|
175
|
+
def _transform(
|
178
176
|
self,
|
179
177
|
data: CupyArray,
|
180
178
|
matrix: CupyArray,
|
@@ -182,21 +180,10 @@ class CupyBackend(NumpyFFTWBackend):
|
|
182
180
|
prefilter: bool,
|
183
181
|
order: int,
|
184
182
|
cache: bool = False,
|
185
|
-
|
186
|
-
) -> None:
|
183
|
+
) -> CupyArray:
|
187
184
|
out_slice = tuple(slice(0, stop) for stop in data.shape)
|
188
|
-
if batched:
|
189
|
-
self.affine_transform_batch(
|
190
|
-
input=data,
|
191
|
-
matrix=matrix,
|
192
|
-
mode="constant",
|
193
|
-
output=output[out_slice],
|
194
|
-
order=order,
|
195
|
-
prefilter=prefilter,
|
196
|
-
)
|
197
|
-
return None
|
198
185
|
|
199
|
-
if data.ndim == 3 and cache and self.texture_available
|
186
|
+
if data.ndim == 3 and cache and self.texture_available:
|
200
187
|
# Device memory pool (should) come to rescue performance
|
201
188
|
temp = self.zeros(data.shape, data.dtype)
|
202
189
|
texture = self._get_texture(data, order=order, prefilter=prefilter)
|
@@ -204,7 +191,7 @@ class CupyBackend(NumpyFFTWBackend):
|
|
204
191
|
output[out_slice] = temp
|
205
192
|
return None
|
206
193
|
|
207
|
-
self.affine_transform(
|
194
|
+
return self.affine_transform(
|
208
195
|
input=data,
|
209
196
|
matrix=matrix,
|
210
197
|
mode="constant",
|