pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,188 @@
|
|
1
|
+
""" Utility functions for jax backend.
|
2
|
+
|
3
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple
|
9
|
+
from functools import partial
|
10
|
+
|
11
|
+
import jax.numpy as jnp
|
12
|
+
from jax import pmap, lax
|
13
|
+
|
14
|
+
from ..types import BackendArray
|
15
|
+
from ..backends import backend as be
|
16
|
+
from ..matching_utils import normalize_template as _normalize_template
|
17
|
+
|
18
|
+
|
19
|
+
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
    """
    Multiply the spectrum of ``template`` with ``ft_target`` and return to real space.

    Both transforms use ``template.shape`` as the real-space size.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.cc_setup`.
    """
    spatial_shape = template.shape
    spectrum = jnp.fft.rfftn(template, s=spatial_shape).at[:].multiply(ft_target)
    return jnp.fft.irfftn(spectrum, s=spatial_shape)
|
27
|
+
|
28
|
+
|
29
|
+
def _flc_scoring(
    template: BackendArray,
    template_mask: BackendArray,
    ft_target: BackendArray,
    ft_target2: BackendArray,
    n_observations: BackendArray,
    eps: float,
    **kwargs,
) -> BackendArray:
    """
    Correlate ``template`` with the target and normalize by the local target std.

    The normalization (reciprocal masked standard deviation of the target) is
    recomputed from ``template_mask`` on every call. Extra keyword arguments
    are accepted and ignored so that this function shares a call signature
    with the other scoring functions.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_scoring`.
    """
    normalization = _reciprocal_target_std(
        ft_target=ft_target,
        ft_target2=ft_target2,
        template_mask=template_mask,
        eps=eps,
        n_observations=n_observations,
    )
    raw_scores = _correlate(template=template, ft_target=ft_target)
    return raw_scores.at[:].multiply(normalization)
|
51
|
+
|
52
|
+
|
53
|
+
def _flcSphere_scoring(
    template: BackendArray,
    ft_target: BackendArray,
    inv_denominator: BackendArray,
    **kwargs,
) -> BackendArray:
    """
    Correlate ``template`` with the target using a precomputed normalization.

    Unlike :py:func:`_flc_scoring`, the reciprocal target standard deviation
    is not recomputed here. It is supplied as ``inv_denominator``, so it can be
    computed once when the template mask does not change between calls.
    Additional keyword arguments (e.g. ``template_mask``, ``ft_target2``,
    ``n_observations``, ``eps``) are accepted and ignored, which keeps the
    call signature interchangeable with :py:func:`_flc_scoring`.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_scoring`.
    """
    raw_scores = _correlate(template=template, ft_target=ft_target)
    return raw_scores.at[:].multiply(inv_denominator)
|
65
|
+
|
66
|
+
|
67
|
+
def _reciprocal_target_std(
    ft_target: BackendArray,
    ft_target2: BackendArray,
    template_mask: BackendArray,
    n_observations: float,
    eps: float,
) -> BackendArray:
    """
    Computes reciprocal standard deviation of a target given a mask.

    The local variance is E(X^2) - E(X)^2, where both moments are obtained by
    Fourier-space multiplication with the mask spectrum and division by
    ``n_observations``. Negative variances from round-off are clamped to 0.
    The result is ``1 / (std * n_observations)``, and positions whose std is
    ``<= eps`` are set to 0.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_scoring`.
    """
    spatial_shape = template_mask.shape
    mask_spectrum = jnp.fft.rfftn(template_mask, s=spatial_shape)

    # Second moment E(X^2) under the mask.
    second_moment = jnp.fft.irfftn(ft_target2 * mask_spectrum, s=spatial_shape)
    second_moment = second_moment.at[:].divide(n_observations)

    # Squared first moment E(X)^2 under the mask.
    first_moment = jnp.fft.irfftn(
        mask_spectrum.at[:].multiply(ft_target), s=spatial_shape
    )
    first_moment = first_moment.at[:].divide(n_observations).at[:].power(2)

    std = second_moment.at[:].add(-first_moment)
    std = std.at[:].max(0).at[:].power(0.5)

    inverse = jnp.where(std <= eps, 0, jnp.reciprocal(std * n_observations))
    return std.at[:].set(inverse)
|
101
|
+
|
102
|
+
|
103
|
+
def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
    """
    Multiply the real FFT of ``arr`` by ``arr_filter`` and transform back.

    The result has the shape of ``arr`` and is written via ``arr.at[:].set``,
    so it keeps the dtype of ``arr``.
    """
    shape = arr.shape
    filtered_spectrum = jnp.fft.rfftn(arr, s=shape).at[:].multiply(arr_filter)
    filtered = jnp.fft.irfftn(filtered_spectrum, s=shape)
    return arr.at[:].set(filtered)
|
107
|
+
|
108
|
+
|
109
|
+
def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
    """
    No-op stand-in for :py:func:`_apply_fourier_filter`.

    Accepts the same arguments, ignores ``arr_filter`` and hands ``arr`` back
    unchanged (the very same object).
    """
    unchanged = arr
    return unchanged
|
111
|
+
|
112
|
+
|
113
|
+
@partial(
    pmap,
    in_axes=(0,) + (None,) * 6,
    static_broadcasted_argnums=[6, 7],
)
def scan(
    target: BackendArray,
    template: BackendArray,
    template_mask: BackendArray,
    rotations: BackendArray,
    template_filter: BackendArray,
    target_filter: BackendArray,
    fast_shape: Tuple[int],
    rotate_mask: bool,
) -> Tuple[BackendArray, BackendArray]:
    """
    Score ``template`` against ``target`` for every entry of ``rotations``.

    The function is wrapped in :py:func:`jax.pmap`. ``target`` is mapped along
    its leading axis (``in_axes`` entry 0), the other array arguments are
    broadcast (``None``), and ``fast_shape`` / ``rotate_mask`` are static
    (``static_broadcasted_argnums=[6, 7]``).

    Parameters
    ----------
    target : BackendArray
        Target, mapped across devices along its leading axis.
    template : BackendArray
        Template that is rotated for each entry of ``rotations``.
    template_mask : BackendArray
        Template mask, rotated alongside ``template``.
    rotations : BackendArray
        Rotation matrices, iterated along the leading axis by
        :py:func:`jax.lax.scan`.
    template_filter : BackendArray
        Fourier filter for each rotated template. Only applied when its shape
        is not ``()``.
    target_filter : BackendArray
        Fourier filter applied once to ``target``.
    fast_shape : tuple of int
        Real-space shape used for all FFTs and for the returned arrays.
    rotate_mask : bool
        If False, the reciprocal target standard deviation is computed once
        from the unrotated ``template_mask`` and reused for every rotation
        (``_flcSphere_scoring``). If True, it is recomputed per rotation from
        the rotated mask (``_flc_scoring``).

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        Per mapped slice: a score array of shape ``fast_shape``, initialized
        to 0, and an int32 rotation-index array of the same shape, initialized
        to -1. Both are updated via ``be.max_score_over_rotations``.
    """
    # Decimal resolution of the template dtype; used as the std threshold.
    eps = jnp.finfo(template.dtype).resolution

    # NOTE(review): under pmap non-static arguments arrive as arrays, so this
    # check is presumably always true. ``template_filter`` below is gated on
    # ``shape != ()`` instead. Confirm the "no filter" placeholder callers pass.
    if hasattr(target_filter, "shape"):
        target = _apply_fourier_filter(target, target_filter)

    ft_target = jnp.fft.rfftn(target, s=fast_shape)
    ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
    # Only the target spectra are used from here on, so the real-space target
    # reference is dropped. The default is per-rotation normalization.
    inv_denominator, target, scoring_func = None, None, _flc_scoring
    if not rotate_mask:
        # Mask is not rotated: precompute the normalization once for all rotations.
        n_observations = jnp.sum(template_mask)
        inv_denominator = _reciprocal_target_std(
            ft_target=ft_target,
            ft_target2=ft_target2,
            template_mask=be.topleft_pad(template_mask, fast_shape),
            eps=eps,
            n_observations=n_observations,
        )
        # The squared-target spectrum is not needed by _flcSphere_scoring.
        ft_target2, scoring_func = None, _flcSphere_scoring

    # Choose the template filter once, outside the scanned step.
    _template_filter_func = _identity
    if template_filter.shape != ():
        _template_filter_func = _apply_fourier_filter

    def _sample_transform(ret, rotation_matrix):
        # lax.scan step. The carry is (max_scores, rotations, index) and
        # ``rotation_matrix`` is one entry of ``rotations``.
        max_scores, rotations, index = ret
        template_rot, template_mask_rot = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation_matrix,
            order=1,  # thats all we get for now (interpolation order)
        )

        # Normalize the rotated template under its rotated mask, then pad to
        # the FFT shape.
        n_observations = jnp.sum(template_mask_rot)
        template_rot = _template_filter_func(template_rot, template_filter)
        template_rot = _normalize_template(
            template_rot, template_mask_rot, n_observations
        )
        template_rot = be.topleft_pad(template_rot, fast_shape)
        template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)

        # Both scoring functions accept this full keyword set and ignore
        # what they do not use.
        scores = scoring_func(
            template=template_rot,
            template_mask=template_mask_rot,
            ft_target=ft_target,
            ft_target2=ft_target2,
            inv_denominator=inv_denominator,
            n_observations=n_observations,
            eps=eps,
        )
        max_scores, rotations = be.max_score_over_rotations(
            scores, max_scores, rotations, index
        )
        return (max_scores, rotations, index + 1), None

    score_space = jnp.zeros(fast_shape)
    rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
    (score_space, rotation_space, _), _ = lax.scan(
        _sample_transform, (score_space, rotation_space, 0), rotations
    )

    return score_space, rotation_space
|
@@ -0,0 +1,294 @@
|
|
1
|
+
""" Backend using cupy for template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
import warnings
|
9
|
+
from importlib.util import find_spec
|
10
|
+
from contextlib import contextmanager
|
11
|
+
from typing import Tuple, Callable, List
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
|
15
|
+
from .npfftw_backend import NumpyFFTWBackend
|
16
|
+
from ..types import CupyArray, NDArray, shm_type
|
17
|
+
|
18
|
+
# Module-level caches shared by all CupyBackend instances:
#   PLAN_CACHE    - device id -> transform shapes last requested in build_fft.
#   TEXTURE_CACHE - id(array) -> voltools StaticVolume built by _get_texture.
PLAN_CACHE, TEXTURE_CACHE = {}, {}
|
20
|
+
|
21
|
+
|
22
|
+
class CupyBackend(NumpyFFTWBackend):
    """
    A cupy-based matching backend.

    Reuses the NumPy/FFTW backend interface with ``cupy`` as the array module.
    FFTs go through ``cupyx.scipy.fft`` and resampling through
    ``cupyx.scipy.ndimage``. Two custom CUDA elementwise kernels are used, for
    score normalization and for the running maximum over rotations.
    """

    def __init__(
        self,
        float_dtype: type = None,
        complex_dtype: type = None,
        int_dtype: type = None,
        overflow_safe_dtype: type = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        float_dtype : type, optional
            Floating point dtype, defaults to ``cupy.float32``.
        complex_dtype : type, optional
            Complex dtype, defaults to ``cupy.complex64``.
        int_dtype : type, optional
            Integer dtype, defaults to ``cupy.int32``.
        overflow_safe_dtype : type, optional
            Dtype for overflow-prone computations, defaults to ``cupy.float32``.
        **kwargs
            Accepted for interface compatibility and ignored.
        """
        import cupy as cp
        import cupyx.scipy.fft as cufft
        from cupyx.scipy.ndimage import affine_transform, maximum_filter
        from ._cupy_utils import affine_transform_batch

        float_dtype = cp.float32 if float_dtype is None else float_dtype
        complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
        int_dtype = cp.int32 if int_dtype is None else int_dtype
        if overflow_safe_dtype is None:
            overflow_safe_dtype = cp.float32

        super().__init__(
            array_backend=cp,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )
        self._cufft = cufft
        self.maximum_filter = maximum_filter
        self.affine_transform = affine_transform
        self.affine_transform_batch = affine_transform_batch

        itype = f"int{self.datatype_bytes(int_dtype) * 8}"
        ftype = f"float{self.datatype_bytes(float_dtype) * 8}"
        # Where ``scores`` exceeds the running maximum, store the score and
        # the rotation index. Used in place by max_score_over_rotations.
        self._max_score_over_rotations = self._array_backend.ElementwiseKernel(
            f"{ftype} internal_scores, {ftype} scores, {itype} rot_index",
            f"{ftype} out1, {itype} rotations",
            "if (internal_scores < scores) {out1 = scores; rotations = rot_index;}",
            "max_score_over_rotations",
        )
        # out = arr / (std * n_obs), with std = sqrt(E(X^2) - E(X)^2) and
        # out = 0 where std < eps.
        self.norm_scores = cp.ElementwiseKernel(
            f"{ftype} arr, {ftype} exp_sq, {ftype} sq_exp, {ftype} n_obs, {ftype} eps",
            f"{ftype} out",
            """
            // tmp1 = E(X)^2; tmp2 = E(X^2), then tmp2 = std
            float tmp1 = sq_exp / n_obs;
            float tmp2 = exp_sq / n_obs;
            tmp1 *= tmp1;

            tmp2 = sqrt(max(tmp2 - tmp1, 0.0));
            // Branch instead of zeroing the numerator: for tmp2 == 0,
            // 0 / (tmp2 * n_obs) would evaluate to NaN.
            out = (tmp2 < eps) ? 0 : arr / (tmp2 * n_obs);
            """,
            "norm_scores",
        )
        self.texture_available = find_spec("voltools") is not None

    def to_backend_array(self, arr: NDArray) -> CupyArray:
        """
        Return ``arr`` unchanged if it is a cupy array on the current device,
        otherwise convert it with ``cupy.asarray``.
        """
        current_device = self._array_backend.cuda.device.get_device_id()
        if (
            isinstance(arr, self._array_backend.ndarray)
            and arr.device.id == current_device
        ):
            return arr
        return self._array_backend.asarray(arr)

    def to_numpy_array(self, arr: CupyArray) -> NDArray:
        """Copy ``arr`` to host memory as a NumPy array."""
        return self._array_backend.asnumpy(arr)

    def to_cpu_array(self, arr: NDArray) -> NDArray:
        """Alias of :py:meth:`to_numpy_array`."""
        return self.to_numpy_array(arr)

    def from_sharedarr(self, arr: CupyArray) -> CupyArray:
        """Return ``arr`` unchanged."""
        return arr

    @staticmethod
    def to_sharedarr(arr: CupyArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged; ``shared_memory_handler`` is ignored."""
        return arr

    def zeros(self, *args, **kwargs):
        """Forward to ``cupy.zeros``."""
        return self._array_backend.zeros(*args, **kwargs)

    def unravel_index(self, indices, shape):
        """Forward to ``cupy.unravel_index`` with ``dims=shape``."""
        return self._array_backend.unravel_index(indices=indices, dims=shape)

    def unique(self, ar, axis=None, *args, **kwargs):
        """
        Unique elements of ``ar``.

        With ``axis=None`` this calls ``cupy.unique``. Otherwise it warns and
        falls back to ``numpy.unique`` on a host copy, converting every result
        array back to the backend. Extra positional arguments are forwarded
        after ``ar`` (e.g. ``return_index``).
        """
        if axis is None:
            # ``ar`` must be positional: passing it as ``ar=`` alongside
            # ``*args`` raises "got multiple values for argument 'ar'".
            return self._array_backend.unique(ar, *args, axis=axis, **kwargs)

        warnings.warn("Axis argument not yet supported in CupY, falling back to NumPy.")
        ret = np.unique(self.to_numpy_array(ar), *args, axis=axis, **kwargs)
        if not isinstance(ret, tuple):
            return self.to_backend_array(ret)
        return tuple(self.to_backend_array(k) for k in ret)

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int],
        inv_output_shape: Tuple[int] = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Return ``(rfftn, irfftn)`` callables with default sizes and axes bound.

        CuPy's FFT plan cache is cleared on the first call for a device and
        whenever the requested ``(fwd_shape, inv_shape)`` differs from the
        previous call on that device.

        Parameters
        ----------
        fwd_shape : tuple of int
            Real-space input shape of the forward transform.
        inv_shape : tuple of int
            Shape of the complex input of the inverse transform.
        inv_output_shape : tuple of int, optional
            Real-space output shape of the inverse transform, defaults to
            ``fwd_shape``.
        fwd_axes, inv_axes : tuple of int, optional
            Axes of the forward and inverse transforms.
        **kwargs
            Accepted for interface compatibility and ignored.

        Returns
        -------
        Tuple[Callable, Callable]
            ``rfftn(arr, out=None, s=..., axes=...)`` and
            ``irfftn(arr, out=None, s=..., axes=...)``. ``out`` is ignored;
            the transformed array is returned.
        """
        cache = self._array_backend.fft.config.get_plan_cache()
        current_device = self._array_backend.cuda.device.get_device_id()

        # Compare as tuples because shapes may be passed as lists or tuples,
        # and ``[1, 2] == (1, 2)`` is False. Clear only when the shapes
        # actually differ from the last request on this device.
        requested = (tuple(fwd_shape), tuple(inv_shape))
        if PLAN_CACHE.get(current_device) != requested:
            cache.clear()

        rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)

        def rfftn(
            arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
        ) -> CupyArray:
            return self.rfftn(arr, s=s, axes=axes)

        def irfftn(
            arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
        ) -> CupyArray:
            return self.irfftn(arr, s=s, axes=axes)

        PLAN_CACHE[current_device] = requested

        return rfftn, irfftn

    def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
        """``cupyx.scipy.fft.rfftn``; ``out`` is ignored."""
        return self._cufft.rfftn(arr, **kwargs)

    def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
        """``cupyx.scipy.fft.irfftn``; ``out`` is ignored."""
        return self._cufft.irfftn(arr, **kwargs)

    def compute_convolution_shapes(
        self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Return ``(convolution_shape, fast_shape, fast_ft_shape)``.

        ``convolution_shape`` is the full linear-convolution size per axis,
        ``fast_shape`` rounds it up with ``next_fast_len(real=True)``, and
        ``fast_ft_shape`` is the matching rfftn output shape.
        """
        from cupyx.scipy.fft import next_fast_len

        convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
        fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
        fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]

        return convolution_shape, fast_shape, fast_ft_shape

    def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
        """
        Return an ``(n_peaks, ndim)`` array of positions that equal the
        maximum within a box of edge ``min_distance`` (constant-mode borders).
        """
        score_box = tuple(min_distance for _ in range(score_space.ndim))
        max_filter = self.maximum_filter(score_space, size=score_box, mode="constant")
        max_filter = max_filter == score_space

        peaks = self._array_backend.array(self._array_backend.nonzero(max_filter)).T
        return peaks

    # The default methods in Cupy were oddly slow
    def var(self, a, *args, **kwargs):
        """
        Population variance: the mean of squared deviations from the mean.

        ``*args``/``**kwargs`` are forwarded to ``cupy.mean`` (e.g. ``axis``).
        """
        xp = self._array_backend
        # The centering mean keeps reduced axes so it broadcasts against ``a``
        # for any ``axis``. Without keepdims, axis=1 on an (n, m) array fails
        # to broadcast, or silently misaligns when n == m. Any ``out`` is only
        # meant for the final reduction.
        center_kwargs = {k: v for k, v in kwargs.items() if k != "out"}
        center_kwargs["keepdims"] = True
        centered = a - xp.mean(a, *args, **center_kwargs)
        xp.square(centered, centered)
        return xp.mean(centered, *args, **kwargs)

    def std(self, a, *args, **kwargs):
        """Square root of :py:meth:`var` with the same arguments."""
        out = self.var(a, *args, **kwargs)
        return self._array_backend.sqrt(out)

    def _get_texture(self, arr: CupyArray, order: int = 3, prefilter: bool = False):
        """
        Return a voltools ``StaticVolume`` for ``arr``, cached by ``id(arr)``.

        The cache is emptied once it holds two entries. ``order=1`` selects
        linear interpolation, ``order=3`` without prefilter selects
        ``bspline``, and anything else selects ``filt_bspline``.
        """
        key = id(arr)
        if key in TEXTURE_CACHE:
            return TEXTURE_CACHE[key]

        from voltools import StaticVolume

        # Only keep template and potential corresponding mask in cache
        if len(TEXTURE_CACHE) >= 2:
            TEXTURE_CACHE.clear()

        interpolation = "filt_bspline"
        if order == 1:
            interpolation = "linear"
        elif order == 3 and not prefilter:
            interpolation = "bspline"

        current_device = self._array_backend.cuda.device.get_device_id()
        TEXTURE_CACHE[key] = StaticVolume(
            arr, interpolation=interpolation, device=f"gpu:{current_device}"
        )

        return TEXTURE_CACHE[key]

    def _rigid_transform(
        self,
        data: CupyArray,
        matrix: CupyArray,
        output: CupyArray,
        prefilter: bool,
        order: int,
        cache: bool = False,
        batched: bool = False,
    ) -> None:
        """
        Affine-transform ``data`` with ``matrix`` into ``output[:data.shape]``.

        Uses the batched kernel when ``batched`` is True. ``cache`` is
        currently unused because the texture path below is disabled.
        """
        out_slice = tuple(slice(0, stop) for stop in data.shape)
        if batched:
            self.affine_transform_batch(
                input=data,
                matrix=matrix,
                mode="constant",
                output=output[out_slice],
                order=order,
                prefilter=prefilter,
            )
            return None

        # if data.ndim == 3 and cache and self.texture_available:
        #     # Device memory pool (should) come to rescue performance
        #     temp = self.zeros(data.shape, data.dtype)
        #     texture = self._get_texture(data, order=order, prefilter=prefilter)
        #     texture.affine(transform_m=matrix, profile=False, output=temp)
        #     output[out_slice] = temp
        #     return None

        self.affine_transform(
            input=data,
            matrix=matrix,
            mode="constant",
            output=output[out_slice],
            order=order,
            prefilter=prefilter,
        )

    def get_available_memory(self) -> int:
        """Free memory in bytes on the current device (``memGetInfo``)."""
        with self._array_backend.cuda.Device():
            free_memory, _ = self._array_backend.cuda.runtime.memGetInfo()
        return free_memory

    @contextmanager
    def set_device(self, device_index: int):
        """Context manager that makes ``device_index`` the current device."""
        with self._array_backend.cuda.Device(device_index):
            yield

    def device_count(self) -> int:
        """Number of visible CUDA devices."""
        return self._array_backend.cuda.runtime.getDeviceCount()

    def max_score_over_rotations(
        self,
        scores: CupyArray,
        max_scores: CupyArray,
        rotations: CupyArray,
        rotation_index: int,
    ) -> Tuple[CupyArray, CupyArray]:
        """
        Update ``max_scores`` and ``rotations`` in place where ``scores`` is
        larger, recording ``rotation_index``, and return both arrays.
        """
        return self._max_score_over_rotations(
            max_scores,
            scores,
            rotation_index,
            max_scores,
            rotations,
        )
|