pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,188 @@
1
+ """ Utility functions for jax backend.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple
9
+ from functools import partial
10
+
11
+ import jax.numpy as jnp
12
+ from jax import pmap, lax
13
+
14
+ from ..types import BackendArray
15
+ from ..backends import backend as be
16
+ from ..matching_utils import normalize_template as _normalize_template
17
+
18
+
19
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
    """
    Correlate a real-space template with a target given in Fourier space.

    Computes :py:meth:`tme.matching_exhaustive.cc_setup`.

    Parameters
    ----------
    template : BackendArray
        Real-space template; its shape sets the forward and inverse transform size.
    ft_target : BackendArray
        Real-to-complex transform of the target, multiplied element-wise with
        ``rfftn(template)``.

    Returns
    -------
    BackendArray
        Real-valued inverse transform of the Fourier-space product, with the
        shape of ``template``.
    """
    fft_shape = template.shape
    product = jnp.fft.rfftn(template, s=fft_shape).at[:].multiply(ft_target)
    return jnp.fft.irfftn(product, s=fft_shape)
27
+
28
+
29
def _flc_scoring(
    template: BackendArray,
    template_mask: BackendArray,
    ft_target: BackendArray,
    ft_target2: BackendArray,
    n_observations: BackendArray,
    eps: float,
    **kwargs,
) -> BackendArray:
    """
    Fast local correlation with the target normalization computed per call.

    Computes :py:meth:`tme.matching_exhaustive.flc_scoring`: the correlation of
    ``template`` with the target, scaled element-wise by the reciprocal
    standard deviation of the target under ``template_mask``.

    Parameters
    ----------
    template : BackendArray
        Real-space template.
    template_mask : BackendArray
        Real-space mask; its shape sets the transform size for the normalization.
    ft_target : BackendArray
        Real-to-complex transform of the target.
    ft_target2 : BackendArray
        Real-to-complex transform of the squared target.
    n_observations : BackendArray
        Number of observations (mask sum) used to form the local moments.
    eps : float
        Standard deviations at or below this value yield a score of 0.
    **kwargs
        Ignored.

    Returns
    -------
    BackendArray
        Normalized correlation scores.
    """
    inv_denominator = _reciprocal_target_std(
        ft_target=ft_target,
        ft_target2=ft_target2,
        template_mask=template_mask,
        eps=eps,
        n_observations=n_observations,
    )
    scores = _correlate(template=template, ft_target=ft_target)
    return scores.at[:].multiply(inv_denominator)
51
+
52
+
53
def _flcSphere_scoring(
    template: BackendArray,
    ft_target: BackendArray,
    inv_denominator: BackendArray,
    **kwargs,
) -> BackendArray:
    """
    Correlate ``template`` with the target and apply a precomputed normalization.

    Variant of :py:meth:`tme.matching_exhaustive.flc_scoring` for the case where
    the reciprocal target standard deviation (``inv_denominator``) has already
    been computed and is reused unchanged.

    Parameters
    ----------
    template : BackendArray
        Real-space template.
    ft_target : BackendArray
        Real-to-complex transform of the target.
    inv_denominator : BackendArray
        Precomputed reciprocal target standard deviation, multiplied
        element-wise into the correlation.
    **kwargs
        Ignored.

    Returns
    -------
    BackendArray
        Normalized correlation scores.
    """
    scores = _correlate(template=template, ft_target=ft_target)
    return scores.at[:].multiply(inv_denominator)
65
+
66
+
67
def _reciprocal_target_std(
    ft_target: BackendArray,
    ft_target2: BackendArray,
    template_mask: BackendArray,
    n_observations: float,
    eps: float,
) -> BackendArray:
    """
    Reciprocal of the masked local standard deviation of the target.

    The local moments are obtained by Fourier-space products with the mask
    transform, giving ``std = sqrt(max(E(X^2) - E(X)^2, 0))``. The result is
    ``1 / (std * n_observations)``, set to 0 wherever ``std <= eps``.

    Parameters
    ----------
    ft_target : BackendArray
        Real-to-complex transform of the target.
    ft_target2 : BackendArray
        Real-to-complex transform of the squared target.
    template_mask : BackendArray
        Real-space mask; its shape is the real-space transform shape.
    n_observations : float
        Divisor for the local moments and scale of the denominator.
    eps : float
        Threshold below (or at) which the reciprocal is replaced by 0.

    Returns
    -------
    BackendArray
        Array with the shape of ``template_mask``.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_scoring`.
    """
    fft_shape = template_mask.shape
    mask_ft = jnp.fft.rfftn(template_mask, s=fft_shape)

    # Second moment E(X^2) under the mask.
    second_moment = jnp.fft.irfftn(ft_target2 * mask_ft, s=fft_shape)
    second_moment = second_moment.at[:].divide(n_observations)

    # Squared first moment E(X)^2 under the mask.
    first_moment = jnp.fft.irfftn(mask_ft.at[:].multiply(ft_target), s=fft_shape)
    first_moment = first_moment.at[:].divide(n_observations)
    first_moment_sq = first_moment.at[:].power(2)

    # Variance, clamped at zero before taking the square root.
    std = second_moment.at[:].add(-first_moment_sq)
    std = std.at[:].max(0)
    std = std.at[:].power(0.5)

    reciprocal = jnp.where(std <= eps, 0, jnp.reciprocal(std * n_observations))
    return std.at[:].set(reciprocal)
101
+
102
+
103
def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
    """
    Filter ``arr`` by multiplying its real-to-complex transform with ``arr_filter``.

    Parameters
    ----------
    arr : BackendArray
        Real-space array; its shape sets the transform size.
    arr_filter : BackendArray
        Multiplier applied to ``rfftn(arr)``.

    Returns
    -------
    BackendArray
        Filtered array, written back into a copy of ``arr`` (keeping its dtype).
    """
    spectrum = jnp.fft.rfftn(arr, s=arr.shape).at[:].multiply(arr_filter)
    filtered = jnp.fft.irfftn(spectrum, s=arr.shape)
    return arr.at[:].set(filtered)
107
+
108
+
109
def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
    """
    No-op counterpart of ``_apply_fourier_filter`` with the same signature.

    ``arr_filter`` is accepted and ignored; ``arr`` is returned unchanged.
    """
    del arr_filter  # unused by design
    return arr
111
+
112
+
113
@partial(
    pmap,
    # NOTE(review): seven in_axes entries for eight parameters; this relies on
    # pmap only consulting in_axes for the non-static positions — confirm.
    in_axes=(0,) + (None,) * 6,
    static_broadcasted_argnums=[6, 7],
)
def scan(
    target: BackendArray,
    template: BackendArray,
    template_mask: BackendArray,
    rotations: BackendArray,
    template_filter: BackendArray,
    target_filter: BackendArray,
    fast_shape: Tuple[int],
    rotate_mask: bool,
) -> Tuple[BackendArray, BackendArray]:
    """
    Score every rotation of ``template`` against ``target`` and keep the best.

    Wrapped in ``jax.pmap``: ``target`` is mapped over its leading axis
    (``in_axes=0``), the remaining array arguments are broadcast to every
    device, and ``fast_shape`` / ``rotate_mask`` are static arguments.

    Parameters
    ----------
    target : BackendArray
        Target, mapped over its leading axis by ``pmap``.
    template : BackendArray
        Template that is rotated for each entry of ``rotations``.
    template_mask : BackendArray
        Mask accompanying ``template``; rotated alongside it.
    rotations : BackendArray
        Rotation matrices, iterated over with ``lax.scan``.
    template_filter : BackendArray
        Fourier filter applied to each rotated template; a 0-d array skips it.
    target_filter : BackendArray
        Fourier filter applied once to ``target`` before scoring.
    fast_shape : Tuple[int]
        Real-space shape of the Fourier transforms.
    rotate_mask : bool
        If False, the target normalization is computed once from the
        unrotated mask and reused for every rotation.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        Float array of shape ``fast_shape`` with the running maximum score over
        rotations, and an int32 array of the same shape with the corresponding
        rotation index (initialised to -1; updates are delegated to
        ``be.max_score_over_rotations``).
    """
    # Normalization denominators at or below this value are treated as zero.
    eps = jnp.finfo(template.dtype).resolution

    # NOTE(review): inside pmap non-static arguments arrive as traced arrays,
    # so this check is presumably always true; the template filter below is
    # gated on ``shape != ()`` instead — confirm which behaviour is intended.
    if hasattr(target_filter, "shape"):
        target = _apply_fourier_filter(target, target_filter)

    # Transforms of the target and its square feed the masked local moments.
    ft_target = jnp.fft.rfftn(target, s=fast_shape)
    ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
    # The real-space target is not used past this point; drop the reference.
    inv_denominator, target, scoring_func = None, None, _flc_scoring
    if not rotate_mask:
        # The mask is not rotated, so the reciprocal target std is computed once
        # here and the scorer that reuses it is selected.
        n_observations = jnp.sum(template_mask)
        inv_denominator = _reciprocal_target_std(
            ft_target=ft_target,
            ft_target2=ft_target2,
            template_mask=be.topleft_pad(template_mask, fast_shape),
            eps=eps,
            n_observations=n_observations,
        )
        ft_target2, scoring_func = None, _flcSphere_scoring

    # Apply the template filter only when it is not a 0-d (scalar) array.
    _template_filter_func = _identity
    if template_filter.shape != ():
        _template_filter_func = _apply_fourier_filter

    def _sample_transform(ret, rotation_matrix):
        # lax.scan step. Carry is (max_scores, rotations, index); no per-step
        # output is emitted. Note ``rotations`` here shadows the outer argument.
        max_scores, rotations, index = ret
        template_rot, template_mask_rot = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation_matrix,
            order=1,  # thats all we get for now
        )

        n_observations = jnp.sum(template_mask_rot)
        template_rot = _template_filter_func(template_rot, template_filter)
        template_rot = _normalize_template(
            template_rot, template_mask_rot, n_observations
        )
        template_rot = be.topleft_pad(template_rot, fast_shape)
        template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)

        # Both scorers take **kwargs, so the full keyword set is passed and each
        # ignores what it does not use.
        scores = scoring_func(
            template=template_rot,
            template_mask=template_mask_rot,
            ft_target=ft_target,
            ft_target2=ft_target2,
            inv_denominator=inv_denominator,
            n_observations=n_observations,
            eps=eps,
        )
        max_scores, rotations = be.max_score_over_rotations(
            scores, max_scores, rotations, index
        )
        return (max_scores, rotations, index + 1), None

    score_space = jnp.zeros(fast_shape)
    rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
    (score_space, rotation_space, _), _ = lax.scan(
        _sample_transform, (score_space, rotation_space, 0), rotations
    )

    return score_space, rotation_space
@@ -0,0 +1,294 @@
1
+ """ Backend using cupy for template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import warnings
9
+ from importlib.util import find_spec
10
+ from contextlib import contextmanager
11
+ from typing import Tuple, Callable, List
12
+
13
+ import numpy as np
14
+
15
+ from .npfftw_backend import NumpyFFTWBackend
16
+ from ..types import CupyArray, NDArray, shm_type
17
+
18
# Per-device record of the transform shapes last planned by
# ``CupyBackend.build_fft``: device id -> [fwd_shape, inv_shape].
PLAN_CACHE = {}
# ``id(array)`` -> voltools ``StaticVolume`` created by ``CupyBackend._get_texture``.
TEXTURE_CACHE = {}


class CupyBackend(NumpyFFTWBackend):
    """
    A cupy-based matching backend.

    Implements the :py:class:`NumpyFFTWBackend` interface on top of CuPy arrays,
    ``cupyx.scipy`` FFT/ndimage routines and custom elementwise CUDA kernels.
    """

    def __init__(
        self,
        float_dtype: type = None,
        complex_dtype: type = None,
        int_dtype: type = None,
        overflow_safe_dtype: type = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        float_dtype : type, optional
            Floating point dtype, ``cupy.float32`` by default.
        complex_dtype : type, optional
            Complex dtype, ``cupy.complex64`` by default.
        int_dtype : type, optional
            Integer dtype, ``cupy.int32`` by default.
        overflow_safe_dtype : type, optional
            Forwarded to :py:class:`NumpyFFTWBackend`, ``cupy.float32`` by default.
        **kwargs
            Accepted for interface compatibility; ignored.
        """
        import cupy as cp
        import cupyx.scipy.fft as cufft
        from cupyx.scipy.ndimage import affine_transform, maximum_filter
        from ._cupy_utils import affine_transform_batch

        float_dtype = cp.float32 if float_dtype is None else float_dtype
        complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
        int_dtype = cp.int32 if int_dtype is None else int_dtype
        if overflow_safe_dtype is None:
            overflow_safe_dtype = cp.float32

        super().__init__(
            array_backend=cp,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )
        self._cufft = cufft
        self.maximum_filter = maximum_filter
        self.affine_transform = affine_transform
        self.affine_transform_batch = affine_transform_batch

        itype = f"int{self.datatype_bytes(int_dtype) * 8}"
        ftype = f"float{self.datatype_bytes(float_dtype) * 8}"
        # Where ``scores`` exceeds ``internal_scores``, write the score to ``out1``
        # and ``rot_index`` to ``rotations``; other positions are not written.
        # max_score_over_rotations passes the running maximum as both
        # ``internal_scores`` and ``out1``.
        self._max_score_over_rotations = self._array_backend.ElementwiseKernel(
            f"{ftype} internal_scores, {ftype} scores, {itype} rot_index",
            f"{ftype} out1, {itype} rotations",
            "if (internal_scores < scores) {out1 = scores; rotations = rot_index;}",
            "max_score_over_rotations",
        )
        # out = arr / (std * n_obs) with std = sqrt(max(E(X^2) - E(X)^2, 0)),
        # where E(X) = sq_exp / n_obs and E(X^2) = exp_sq / n_obs. Positions with
        # std < eps are set to 0.
        self.norm_scores = cp.ElementwiseKernel(
            f"{ftype} arr, {ftype} exp_sq, {ftype} sq_exp, {ftype} n_obs, {ftype} eps",
            f"{ftype} out",
            """
            // tmp1 = E(X)^2; tmp2 = E(X^2)
            float tmp1 = sq_exp / n_obs;
            float tmp2 = exp_sq / n_obs;
            tmp1 *= tmp1;

            tmp2 = sqrt(max(tmp2 - tmp1, 0.0));
            // Branch instead of zeroing the numerator: with std == 0 the
            // division would evaluate 0 / 0 = NaN.
            if (tmp2 < eps) {
                out = 0;
            } else {
                out = arr / (tmp2 * n_obs);
            }
            """,
            "norm_scores",
        )
        self.texture_available = find_spec("voltools") is not None

    def to_backend_array(self, arr: NDArray) -> CupyArray:
        """Return ``arr`` unchanged if it already lives on the current device, else copy it there."""
        current_device = self._array_backend.cuda.device.get_device_id()
        if (
            isinstance(arr, self._array_backend.ndarray)
            and arr.device.id == current_device
        ):
            return arr
        return self._array_backend.asarray(arr)

    def to_numpy_array(self, arr: CupyArray) -> NDArray:
        """Copy ``arr`` to host memory as a NumPy array."""
        return self._array_backend.asnumpy(arr)

    def to_cpu_array(self, arr: NDArray) -> NDArray:
        """Alias of :py:meth:`to_numpy_array`."""
        return self.to_numpy_array(arr)

    def from_sharedarr(self, arr: CupyArray) -> CupyArray:
        """Return ``arr`` unchanged; CuPy arrays are not wrapped in shared memory."""
        return arr

    @staticmethod
    def to_sharedarr(arr: CupyArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged; ``shared_memory_handler`` is ignored."""
        return arr

    def zeros(self, *args, **kwargs):
        """Forward to the array backend's ``zeros``."""
        return self._array_backend.zeros(*args, **kwargs)

    def unravel_index(self, indices, shape):
        """Convert flat ``indices`` into coordinate arrays for ``shape``."""
        return self._array_backend.unravel_index(indices=indices, dims=shape)

    def unique(self, ar, axis=None, *args, **kwargs):
        """
        Unique elements of ``ar``.

        Without ``axis`` the array backend's ``unique`` is used; with ``axis`` the
        computation falls back to NumPy and results are copied back to the device.
        Extra positional arguments are forwarded after ``ar``.
        """
        if axis is None:
            # ``ar`` is passed positionally so that ``*args`` cannot collide with it.
            return self._array_backend.unique(ar, *args, axis=axis, **kwargs)

        warnings.warn("Axis argument not yet supported in CupY, falling back to NumPy.")
        ret = np.unique(self.to_numpy_array(ar), *args, axis=axis, **kwargs)
        if not isinstance(ret, tuple):
            return self.to_backend_array(ret)
        return tuple(self.to_backend_array(k) for k in ret)

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int],
        inv_output_shape: Tuple[int] = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Return ``(rfftn, irfftn)`` callables with transform shapes and axes bound.

        The cuFFT plan cache of the current device is cleared whenever the
        requested ``fwd_shape``/``inv_shape`` differ from the previous call on
        that device (and on the first call); identical shapes keep cached plans.

        Parameters
        ----------
        fwd_shape : Tuple[int]
            Real-space shape of the forward transform.
        inv_shape : Tuple[int]
            Fourier-space shape of the inverse transform input.
        inv_output_shape : Tuple[int], optional
            Real-space output shape of the inverse transform, ``fwd_shape`` by default.
        fwd_axes, inv_axes : Tuple[int], optional
            Axes of the forward and inverse transforms.
        **kwargs
            Ignored.
        """
        cache = self._array_backend.fft.config.get_plan_cache()
        current_device = self._array_backend.cuda.device.get_device_id()

        # Compare as tuples so list- and tuple-valued shapes are treated alike.
        current_transform = [tuple(fwd_shape), tuple(inv_shape)]
        if PLAN_CACHE.get(current_device) != current_transform:
            cache.clear()

        rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)

        def rfftn(
            arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
        ) -> CupyArray:
            # ``out`` is accepted for interface compatibility and ignored.
            return self.rfftn(arr, s=s, axes=axes)

        def irfftn(
            arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
        ) -> CupyArray:
            # ``out`` is accepted for interface compatibility and ignored.
            return self.irfftn(arr, s=s, axes=axes)

        PLAN_CACHE[current_device] = current_transform

        return rfftn, irfftn

    def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
        """Real-to-complex FFT via ``cupyx.scipy.fft``; ``out`` is ignored."""
        return self._cufft.rfftn(arr, **kwargs)

    def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
        """Complex-to-real inverse FFT via ``cupyx.scipy.fft``; ``out`` is ignored."""
        return self._cufft.irfftn(arr, **kwargs)

    def compute_convolution_shapes(
        self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Return the full convolution shape, a cuFFT-friendly padded shape, and the
        matching real-to-complex Fourier shape.
        """
        from cupyx.scipy.fft import next_fast_len

        convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
        fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
        fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]

        return convolution_shape, fast_shape, fast_ft_shape

    def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
        """
        Coordinates where ``score_space`` equals its maximum filter with a box of
        ``min_distance`` along every axis (constant-mode boundaries).
        """
        score_box = tuple(min_distance for _ in range(score_space.ndim))
        max_filter = self.maximum_filter(score_space, size=score_box, mode="constant")
        max_filter = max_filter == score_space

        peaks = self._array_backend.array(self._array_backend.nonzero(max_filter)).T
        return peaks

    # The default methods in Cupy were oddly slow
    def var(self, a, *args, **kwargs):
        """
        Variance of ``a`` computed as ``mean((a - mean(a))**2)``.

        Arguments are forwarded to ``mean``.
        """
        xp = self._array_backend
        # Keep reduced axes so the mean broadcasts against ``a`` for any ``axis``;
        # ``out`` only applies to the final reduction.
        mean_kwargs = {k: v for k, v in kwargs.items() if k != "out"}
        mean_kwargs["keepdims"] = True
        out = a - xp.mean(a, *args, **mean_kwargs)
        xp.square(out, out)
        return xp.mean(out, *args, **kwargs)

    def std(self, a, *args, **kwargs):
        """Standard deviation of ``a``; arguments as for :py:meth:`var`."""
        out = self.var(a, *args, **kwargs)
        return self._array_backend.sqrt(out)

    def _get_texture(self, arr: CupyArray, order: int = 3, prefilter: bool = False):
        """
        Return a voltools ``StaticVolume`` for ``arr``, cached by ``id(arr)``.

        The cache is emptied once it holds two entries. Interpolation is
        ``linear`` for ``order == 1``, ``bspline`` for ``order == 3`` without
        prefiltering and ``filt_bspline`` otherwise.
        """
        # NOTE(review): id() values can be reused after an array is freed, which
        # could return a stale texture for a new array — confirm this is safe.
        key = id(arr)
        if key in TEXTURE_CACHE:
            return TEXTURE_CACHE[key]

        from voltools import StaticVolume

        # Only keep template and potential corresponding mask in cache
        if len(TEXTURE_CACHE) >= 2:
            TEXTURE_CACHE.clear()

        interpolation = "filt_bspline"
        if order == 1:
            interpolation = "linear"
        elif order == 3 and not prefilter:
            interpolation = "bspline"

        current_device = self._array_backend.cuda.device.get_device_id()
        TEXTURE_CACHE[key] = StaticVolume(
            arr, interpolation=interpolation, device=f"gpu:{current_device}"
        )

        return TEXTURE_CACHE[key]

    def _rigid_transform(
        self,
        data: CupyArray,
        matrix: CupyArray,
        output: CupyArray,
        prefilter: bool,
        order: int,
        cache: bool = False,
        batched: bool = False,
    ) -> None:
        """
        Affine-transform ``data`` into the leading ``data.shape`` region of ``output``.

        Uses ``affine_transform_batch`` when ``batched`` is set, otherwise
        ``cupyx.scipy.ndimage.affine_transform``; both use constant-mode
        boundaries. ``cache`` is currently unused (the texture path is disabled).
        """
        out_slice = tuple(slice(0, stop) for stop in data.shape)
        if batched:
            self.affine_transform_batch(
                input=data,
                matrix=matrix,
                mode="constant",
                output=output[out_slice],
                order=order,
                prefilter=prefilter,
            )
            return None

        # if data.ndim == 3 and cache and self.texture_available:
        #     # Device memory pool (should) come to rescue performance
        #     temp = self.zeros(data.shape, data.dtype)
        #     texture = self._get_texture(data, order=order, prefilter=prefilter)
        #     texture.affine(transform_m=matrix, profile=False, output=temp)
        #     output[out_slice] = temp
        #     return None

        self.affine_transform(
            input=data,
            matrix=matrix,
            mode="constant",
            output=output[out_slice],
            order=order,
            prefilter=prefilter,
        )

    def get_available_memory(self) -> int:
        """Free memory in bytes on the current CUDA device."""
        with self._array_backend.cuda.Device():
            free_memory, _ = self._array_backend.cuda.runtime.memGetInfo()
        return free_memory

    @contextmanager
    def set_device(self, device_index: int):
        """Context manager that makes ``device_index`` the current CUDA device."""
        with self._array_backend.cuda.Device(device_index):
            yield

    def device_count(self) -> int:
        """Number of visible CUDA devices."""
        return self._array_backend.cuda.runtime.getDeviceCount()

    def max_score_over_rotations(
        self,
        scores: CupyArray,
        max_scores: CupyArray,
        rotations: CupyArray,
        rotation_index: int,
    ) -> Tuple[CupyArray, CupyArray]:
        """
        Update ``max_scores`` and ``rotations`` in place wherever ``scores`` is larger.

        Returns the kernel's outputs, which are ``max_scores`` and ``rotations``.
        """
        return self._max_score_over_rotations(
            max_scores,
            scores,
            rotation_index,
            max_scores,
            rotations,
        )