pytme-0.2.9-cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,583 @@
""" Backend using numpy and pyFFTW for template matching.

Copyright (c) 2023 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import os
from psutil import virtual_memory
from contextlib import contextmanager
from typing import Tuple, Dict, List, Type

import scipy
import numpy as np
from scipy.ndimage import maximum_filter, affine_transform
from pyfftw.builders import rfftn as rfftn_builder, irfftn as irfftn_builder
from pyfftw import zeros_aligned, simd_alignment, FFTW, next_fast_len, interfaces

from ..types import NDArray, BackendArray, shm_type
from .matching_backend import MatchingBackend, _create_metafunction

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["PYFFTW_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"


def create_ufuncs(obj):
    ufuncs = [
        "add",
        "subtract",
        "multiply",
        "divide",
        "mod",
        "sum",
        "where",
        "einsum",
        "mean",
        "einsum",
        "std",
        "max",
        "min",
        "maximum",
        "minimum",
        "sqrt",
        "square",
        "abs",
        "power",
        "full",
        "clip",
        "arange",
        "stack",
        "concatenate",
        "repeat",
        "indices",
        "unique",
        "argsort",
        "tril_indices",
        "reshape",
        "identity",
        "dot",
        "copy",
    ]
    for ufunc in ufuncs:
        setattr(obj, ufunc, _create_metafunction(ufunc))
    return obj


@create_ufuncs
class _NumpyWrapper:
    """
    MatchingBackend prohibits using create_ufuncs on NumpyFFTWBackend directly.
    """

    pass


class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
    """
    A numpy and pyfftw-based matching backend.
    """

    def __init__(
        self,
        array_backend=np,
        float_dtype=np.float32,
        complex_dtype=np.complex64,
        int_dtype=np.int32,
        overflow_safe_dtype=np.float32,
        **kwargs,
    ):
        super().__init__(
            array_backend=array_backend,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )
        self.affine_transform = affine_transform

        self.cholesky = self._linalg_cholesky
        self.solve_triangular = self._solve_triangular
        self.linalg.solve_triangular = scipy.linalg.solve_triangular

    def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
        # Upper argument is not supported until numpy 2.0
        ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
        if not lower:
            axes = list(range(ret.ndim))
            axes[-2:] = (ret.ndim - 1, ret.ndim - 2)
            ret = self._array_backend.transpose(ret, axes)
        return ret

    def _solve_triangular(self, a, b, lower=True, *args, **kwargs):
        mask = self._array_backend.tril if lower else self._array_backend.triu
        return self._array_backend.linalg.solve(mask(a), b, *args, **kwargs)

    def to_backend_array(self, arr: NDArray) -> NDArray:
        if isinstance(arr, self._array_backend.ndarray):
            return arr
        return self._array_backend.asarray(arr)

    def to_numpy_array(self, arr: NDArray) -> NDArray:
        return np.array(arr)

    def to_cpu_array(self, arr: NDArray) -> NDArray:
        return arr

    def get_fundamental_dtype(self, arr: NDArray) -> Type:
        dt = arr.dtype
        if self._array_backend.issubdtype(dt, self._array_backend.integer):
            return int
        elif self._array_backend.issubdtype(dt, self._array_backend.floating):
            return float
        elif self._array_backend.issubdtype(dt, self._array_backend.complexfloating):
            return complex
        return float

    def free_cache(self):
        pass

    def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
        return self._array_backend.transpose(arr, *args, **kwargs)

    def tobytes(self, arr: NDArray) -> str:
        return arr.tobytes()

    def size(self, arr: NDArray) -> int:
        return arr.size

    def fill(self, arr: NDArray, value: float) -> NDArray:
        arr.fill(value)
        return arr

    def eps(self, dtype: type) -> NDArray:
        return self._array_backend.finfo(dtype).eps

    def datatype_bytes(self, dtype: type) -> NDArray:
        temp = self._array_backend.zeros(1, dtype=dtype)
        return temp.nbytes

    @staticmethod
    def astype(arr, dtype: Type) -> NDArray:
        return arr.astype(dtype)

    @staticmethod
    def at(arr, idx, value) -> NDArray:
        arr[idx] = value
        return arr

    def addat(self, arr, indices, *args, **kwargs) -> NDArray:
        self._array_backend.add.at(arr, indices, *args, **kwargs)
        return arr

    def topk_indices(self, arr: NDArray, k: int):
        temp = arr.reshape(-1)
        indices = self._array_backend.argpartition(temp, -k)[-k:][:k]
        sorted_indices = indices[self._array_backend.argsort(temp[indices])][::-1]
        sorted_indices = self.unravel_index(indices=sorted_indices, shape=arr.shape)
        return sorted_indices

    def indices(self, *args, **kwargs) -> NDArray:
        return self._array_backend.indices(*args, **kwargs)

    def roll(
        self, a: NDArray, shift: Tuple[int], axis: Tuple[int], **kwargs
    ) -> NDArray:
        return self._array_backend.roll(
            a,
            shift=shift,
            axis=axis,
            **kwargs,
        )

    def unravel_index(self, indices: NDArray, shape: Tuple[int]) -> NDArray:
        return self._array_backend.unravel_index(indices=indices, shape=shape)

    def max_filter_coordinates(self, score_space: NDArray, min_distance: Tuple[int]):
        score_box = tuple(min_distance for _ in range(score_space.ndim))
        max_filter = maximum_filter(score_space, size=score_box, mode="constant")
        max_filter = max_filter == score_space

        peaks = np.array(np.nonzero(max_filter)).T
        return peaks

    @staticmethod
    def zeros(shape: Tuple[int], dtype: type = None) -> NDArray:
        arr = zeros_aligned(shape, dtype=dtype, n=simd_alignment)
        return arr

    def from_sharedarr(self, args) -> NDArray:
        if len(args) == 1:
            return args[0]
        shm, shape, dtype = args
        return self.ndarray(shape, dtype, shm.buf)

    def to_sharedarr(
        self, arr: NDArray, shared_memory_handler: type = None
    ) -> shm_type:
        if shared_memory_handler is None:
            return (arr,)

        shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
        np_array = self.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
        np_array[:] = arr[:].copy()
        return shm, arr.shape, arr.dtype

    def topleft_pad(self, arr: NDArray, shape: Tuple[int], padval: int = 0) -> NDArray:
        b = self.zeros(shape, arr.dtype)
        self.add(b, padval, out=b)
        aind = [slice(None, None)] * arr.ndim
        bind = [slice(None, None)] * arr.ndim
        for i in range(arr.ndim):
            if arr.shape[i] > shape[i]:
                aind[i] = slice(0, shape[i])
            elif arr.shape[i] < shape[i]:
                bind[i] = slice(0, arr.shape[i])
        b[tuple(bind)] = arr[tuple(aind)]
        return b

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int],
        real_dtype: type,
        cmpl_dtype: type,
        fftargs: Dict = {},
        inv_output_shape: Tuple[int] = None,
        temp_fwd: NDArray = None,
        temp_inv: NDArray = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
    ) -> Tuple[FFTW, FFTW]:
        if temp_fwd is None:
            temp_fwd = (
                self.zeros(fwd_shape, real_dtype) if temp_fwd is None else temp_fwd
            )
        if temp_inv is None:
            temp_inv = (
                self.zeros(inv_shape, cmpl_dtype) if temp_inv is None else temp_inv
            )

        default_values = {
            "planner_effort": "FFTW_MEASURE",
            "auto_align_input": False,
            "auto_contiguous": False,
            "avoid_copy": True,
            "overwrite_input": True,
            "threads": 1,
        }
        for key in default_values:
            if key in fftargs:
                continue
            fftargs[key] = default_values[key]

        rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
        _rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)
        overwrite_input = fftargs.pop("overwrite_input", None)

        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
        _irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)

        def _rfftn_wrapper(arr, out, *args, **kwargs):
            return _rfftn(arr, out)

        def _irfftn_wrapper(arr, out, *args, **kwargs):
            return _irfftn(arr, out)

        fftargs["overwrite_input"] = overwrite_input
        return _rfftn_wrapper, _irfftn_wrapper

    @staticmethod
    def _format_fft_shape(shape: Tuple[int], axes: Tuple[int] = None):
        if axes is None:
            return shape
        axes = tuple(sorted(range(len(shape))[i] for i in axes))
        return tuple(shape[i] for i in axes)

    def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
        return interfaces.numpy_fft.rfftn(arr, **kwargs)

    def irfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
        return interfaces.numpy_fft.irfftn(arr, **kwargs)

    def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
        new_shape = self.to_backend_array(newshape)
        current_shape = self.to_backend_array(arr.shape)
        starts = self.subtract(current_shape, new_shape)
        starts = self.astype(self.divide(starts, 2), self._int_dtype)
        stops = self.astype(self.add(starts, new_shape), self._int_dtype)
        box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
        return arr[box]

    def compute_convolution_shapes(
        self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
    ) -> Tuple[List[int], List[int], List[int]]:
        convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
        fast_shape = [next_fast_len(x) for x in convolution_shape]
        fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]

        return convolution_shape, fast_shape, fast_ft_shape

    def _rigid_transform_matrix(
        self,
        rotation_matrix: NDArray,
        translation: NDArray = None,
        center: NDArray = None,
    ) -> NDArray:
        ndim = rotation_matrix.shape[0]
        matrix = self.identity(ndim + 1, dtype=self._float_dtype)

        if translation is not None:
            translation_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
            translation_matrix[:ndim, ndim] = -translation
            self.dot(matrix, translation_matrix, out=matrix)

        if center is not None:
            center_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
            center_matrix[:ndim, ndim] = center
            self.dot(matrix, center_matrix, out=matrix)

        if rotation_matrix is not None:
            rmat = self.identity(ndim + 1, dtype=self._float_dtype)
            rmat[:ndim, :ndim] = self._array_backend.linalg.inv(rotation_matrix)
            self.dot(matrix, rmat, out=matrix)

        if center is not None:
            center_matrix[:ndim, ndim] = -center_matrix[:ndim, ndim]
            self.dot(matrix, center_matrix, out=matrix)

        matrix /= matrix[ndim, ndim]
        return matrix

    def _rigid_transform(
        self,
        data: NDArray,
        matrix: NDArray,
        output: NDArray,
        prefilter: bool,
        order: int,
        cache: bool = False,
        batched=False,
    ) -> None:
        if batched:
            for i in range(data.shape[0]):
                self._rigid_transform(
                    data=data[i],
                    matrix=matrix,
                    output=output[i],
                    prefilter=prefilter,
                    order=order,
                    cache=cache,
                    batched=False,
                )
            return None

        out_slice = tuple(slice(0, stop) for stop in data.shape)
        self.affine_transform(
            input=data,
            matrix=matrix,
            mode="constant",
            output=output[out_slice],
            order=order,
            prefilter=prefilter,
        )

    def rigid_transform(
        self,
        arr: NDArray,
        rotation_matrix: NDArray,
        arr_mask: NDArray = None,
        translation: NDArray = None,
        use_geometric_center: bool = False,
        out: NDArray = None,
        out_mask: NDArray = None,
        order: int = 3,
        cache: bool = False,
    ) -> Tuple[NDArray, NDArray]:
        out = self.zeros_like(arr) if out is None else out
        batched = arr.ndim != rotation_matrix.shape[0]

        center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
        if not use_geometric_center:
            center = self.center_of_mass(arr, cutoff=0)

        offset = int(arr.ndim - rotation_matrix.shape[0])
        center = center[offset:]
        translation = self.zeros(center.size) if translation is None else translation
        matrix = self._rigid_transform_matrix(
            rotation_matrix=rotation_matrix,
            translation=translation,
            center=center,
        )

        subset = tuple(slice(None) for _ in range(arr.ndim))
        if offset > 1:
            subset = tuple(
                0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
            )

        self._rigid_transform(
            data=arr[subset],
            matrix=matrix,
            output=out[subset],
            order=order,
            prefilter=True,
            cache=cache,
            batched=batched,
        )

        # Applying the prefilter leads to artifacts in the mask.
        if arr_mask is not None:
            out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
            self._rigid_transform(
                data=arr_mask[subset],
                matrix=matrix,
                output=out_mask[subset],
                order=order,
                prefilter=False,
                cache=cache,
                batched=batched,
            )

        return out, out_mask

    def center_of_mass(self, arr: BackendArray, cutoff: float = None) -> BackendArray:
        """
        Computes the center of mass of a numpy ndarray instance using all available
        elements. For template matching it typically makes sense to only input
        positive densities.

        Parameters
        ----------
        arr : BackendArray
            Array to compute the center of mass of.
        cutoff : float, optional
            Densities less than or equal to cutoff are nullified for center
            of mass computation. By default considers all values.

        Returns
        -------
        BackendArray
            Center of mass with shape (arr.ndim).
        """
        cutoff = self.min(arr) - 1 if cutoff is None else cutoff

        arr = self.where(arr > cutoff, arr, 0)
        denominator = self.sum(arr)

        grids = []
        for i, x in enumerate(arr.shape):
            baseline_dims = tuple(1 if i != t else x for t in range(len(arr.shape)))
            grids.append(
                self.reshape(self.arange(x, dtype=self._float_dtype), baseline_dims)
            )

        center_of_mass = [self.sum((arr * grid) / denominator) for grid in grids]

        return self.to_backend_array(center_of_mass)

    def get_available_memory(self) -> int:
        return virtual_memory().available

    @contextmanager
    def set_device(self, device_index: int):
        yield None

    def device_count(self) -> int:
        return 1

    @staticmethod
    def reverse(arr: NDArray, axis: Tuple[int] = None) -> NDArray:
        if axis is None:
            axis = tuple(range(arr.ndim))
        keep, rev = slice(None, None), slice(None, None, -1)
        return arr[tuple(rev if i in axis else keep for i in range(arr.ndim))]

    def max_score_over_rotations(
        self,
        scores: BackendArray,
        max_scores: BackendArray,
        rotations: BackendArray,
        rotation_index: int,
    ) -> None:
        """
        Update elements in ``max_scores`` and ``rotations`` where scores is larger than
        max_scores with score and rotation_index, respectively.

        .. warning:: ``max_scores`` and ``rotations`` are modified in-place.

        Parameters
        ----------
        scores : BackendArray
            The score space to compare against max_scores.
        max_scores : BackendArray
            Maximum score observed for each element in an array.
        rotations : BackendArray
            Rotation used to achieve a given max_score.
        rotation_index : int
            The index representing the current rotation.

        Returns
        -------
        Tuple[BackendArray, BackendArray]
            Updated ``max_scores`` and ``rotations``.
        """
        indices = scores > max_scores
        max_scores[indices] = scores[indices]
        rotations[indices] = rotation_index
        return max_scores, rotations

    def norm_scores(
        self,
        arr: BackendArray,
        exp_sq: BackendArray,
        sq_exp: BackendArray,
        n_obs: int,
        eps: float,
        out: BackendArray,
    ) -> BackendArray:
        """
        Normalizes ``arr`` by the standard deviation ensuring numerical stability.

        Parameters
        ----------
        arr : BackendArray
            The input array to be normalized.
        exp_sq : BackendArray
            Non-normalized expectation square.
        sq_exp : BackendArray
            Non-normalized expectation.
        n_obs : int
            Number of observations for normalization.
        eps : float
            Numbers below this threshold will be ignored in division.
        out : BackendArray
            Output array to write the result to.

        Returns
        -------
        BackendArray
            The normalized array with the same shape as `arr`.

        See Also
        --------
        :py:meth:`tme.matching_exhaustive.flc_scoring`
        """
        # Squared expected value (E(X)^2)
        sq_exp = self.divide(sq_exp, n_obs, out=sq_exp)
        sq_exp = self.square(sq_exp, out=sq_exp)
        # Expected squared value (E(X^2))
        exp_sq = self.divide(exp_sq, n_obs, out=exp_sq)
        # Variance
        sq_exp = self.subtract(exp_sq, sq_exp, out=sq_exp)
        sq_exp = self.maximum(sq_exp, 0.0, out=sq_exp)
        sq_exp = self.sqrt(sq_exp, out=sq_exp)

        # Assume that low stdev regions also have low scores
        # See :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup` for correct norm
        sq_exp[sq_exp < eps] = 1
        sq_exp = self.multiply(sq_exp, n_obs, out=sq_exp)
        return self.divide(arr, sq_exp, out=out)