pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
tme/matching_data.py
ADDED
@@ -0,0 +1,863 @@
""" Class representation of template matching data.

    Copyright (c) 2023 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import warnings
from typing import Tuple, List, Optional, Generator, Dict

import numpy as np

from . import Density
from .filters import Compose
from .backends import backend as be
from .types import BackendArray, NDArray
from .matching_utils import compute_parallelization_schedule

__all__ = ["MatchingData"]


class MatchingData:
    """
    Contains data required for template matching.

    Parameters
    ----------
    target : np.ndarray or :py:class:`tme.density.Density`
        Target data.
    template : np.ndarray or :py:class:`tme.density.Density`
        Template data.
    target_mask : np.ndarray or :py:class:`tme.density.Density`, optional
        Target mask data.
    template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
        Template mask data.
    invert_target : bool, optional
        Whether to invert the target before template matching.
    rotations : np.ndarray, optional
        Template rotations to sample. Can be a single (d, d) or a stack (n, d, d)
        of rotation matrices where d is the dimension of the template.

    Examples
    --------
    The following achieves the minimal definition of a :py:class:`MatchingData` instance.

    >>> import numpy as np
    >>> from tme.matching_data import MatchingData
    >>> target = np.random.rand(50,40,60)
    >>> template = target[15:25, 10:20, 30:40]
    >>> matching_data = MatchingData(target=target, template=template)

    """

    def __init__(
        self,
        target: NDArray,
        template: NDArray,
        template_mask: NDArray = None,
        target_mask: NDArray = None,
        invert_target: bool = False,
        rotations: NDArray = None,
    ):
        self.target = target
        self.target_mask = target_mask

        self.template = template
        if template_mask is not None:
            self.template_mask = template_mask

        self.rotations = rotations
        self._invert_target = invert_target
        self._translation_offset = tuple(0 for _ in range(len(target.shape)))

        self._set_matching_dimension()

    @staticmethod
    def _shape_to_slice(shape: Tuple[int]) -> Tuple[slice]:
        return tuple(slice(0, dim) for dim in shape)

    @classmethod
    def _slice_to_mesh(cls, slice_variable: Tuple[slice], shape: Tuple[int]) -> NDArray:
        if slice_variable is None:
            slice_variable = cls._shape_to_slice(shape)
        ranges = [range(slc.start, slc.stop) for slc in slice_variable]
        indices = np.meshgrid(*ranges, sparse=True, indexing="ij")
        return indices

    @staticmethod
    def _load_array(arr: BackendArray) -> BackendArray:
        """Load ``arr``; if ``arr`` is a :obj:`numpy.memmap`, reload it from disk."""
        if isinstance(arr, np.memmap):
            return np.memmap(arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype)
        return arr

    def subset_array(
        self,
        arr: NDArray,
        arr_slice: Tuple[slice],
        padding: NDArray,
        invert: bool = False,
    ) -> NDArray:
        """
        Extract a subset of the input array according to the given slice and
        apply padding. If the padding exceeds the array dimensions, the
        padded regions are filled by reflection of the boundaries. Otherwise,
        the values in ``arr`` are used.

        Parameters
        ----------
        arr : NDArray
            The input array from which a subset is extracted.
        arr_slice : tuple of slice
            Defines the region of the input array to be extracted.
        padding : NDArray
            Padding values for each dimension.
        invert : bool, optional
            Whether the returned array should be inverted.

        Returns
        -------
        NDArray
            Subset of the input array with padding applied.
        """
        padding = be.to_numpy_array(padding)
        padding = np.maximum(padding, 0).astype(int)

        slice_start = np.array([x.start for x in arr_slice], dtype=int)
        slice_stop = np.array([x.stop for x in arr_slice], dtype=int)

        # We are deviating from our typical right_pad + mod here
        # because cropping from full convolution mode to target shape
        # is defined from the perspective of the origin
        right_pad = np.divide(padding, 2).astype(int)
        left_pad = np.add(right_pad, np.mod(padding, 2))

        data_voxels_left = np.minimum(slice_start, left_pad)
        data_voxels_right = np.minimum(
            np.subtract(arr.shape, slice_stop), right_pad
        ).astype(int)

        arr_start = np.subtract(slice_start, data_voxels_left)
        arr_stop = np.add(slice_stop, data_voxels_right)
        arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
        arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)

        # Note different from joblib memmaps, the memmaps created by
        # Density are guaranteed to only contain the array of interest
        if isinstance(arr, Density):
            if isinstance(arr.data, np.memmap):
                arr = Density.from_file(arr.data.filename, subset=arr_slice).data
            else:
                arr = np.asarray(arr.data[*arr_mesh])
        else:
            arr = np.asarray(arr[*arr_mesh])

        padding = tuple(
            (left, right)
            for left, right in zip(
                np.subtract(left_pad, data_voxels_left),
                np.subtract(right_pad, data_voxels_right),
            )
        )
        # The reflections are later cropped from the scores
        arr = np.pad(arr, padding, mode="reflect")

        if invert:
            arr = -arr
        return arr

    def subset_by_slice(
        self,
        target_slice: Tuple[slice] = None,
        template_slice: Tuple[slice] = None,
        target_pad: NDArray = None,
        template_pad: NDArray = None,
        invert_target: bool = False,
    ) -> "MatchingData":
        """
        Subset class instance based on slices.

        Parameters
        ----------
        target_slice : tuple of slice, optional
            Target subset to use, all by default.
        template_slice : tuple of slice, optional
            Template subset to use, all by default.
        target_pad : BackendArray, optional
            Target padding, zero by default.
        template_pad : BackendArray, optional
            Template padding, zero by default.

        Returns
        -------
        :py:class:`MatchingData`
            Newly allocated subset of class instance.

        Examples
        --------
        >>> import numpy as np
        >>> from tme.matching_data import MatchingData
        >>> target = np.random.rand(50,40,60)
        >>> template = target[15:25, 10:20, 30:40]
        >>> matching_data = MatchingData(target=target, template=template)
        >>> subset = matching_data.subset_by_slice(
        >>>     target_slice=(slice(0, 10), slice(10,20), slice(15,35))
        >>> )
        """
        if target_slice is None:
            target_slice = self._shape_to_slice(self._target.shape)
        if template_slice is None:
            template_slice = self._shape_to_slice(self._template.shape)

        if target_pad is None:
            target_pad = np.zeros(len(self._target.shape), dtype=int)
        if template_pad is None:
            template_pad = np.zeros(len(self._template.shape), dtype=int)

        target_mask, template_mask = None, None
        target_subset = self.subset_array(
            self._target, target_slice, target_pad, invert=self._invert_target
        )
        template_subset = self.subset_array(
            arr=self._template, arr_slice=template_slice, padding=template_pad
        )
        if self._target_mask is not None:
            mask_slice = zip(target_slice, self._target_mask.shape)
            mask_slice = tuple(x if t != 1 else slice(0, 1) for x, t in mask_slice)
            target_mask = self.subset_array(
                arr=self._target_mask, arr_slice=mask_slice, padding=target_pad
            )
        if self._template_mask is not None:
            mask_slice = zip(template_slice, self._template_mask.shape)
            mask_slice = tuple(x if t != 1 else slice(0, 1) for x, t in mask_slice)
            template_mask = self.subset_array(
                arr=self._template_mask, arr_slice=mask_slice, padding=template_pad
            )

        ret = self.__class__(
            target=target_subset,
            template=template_subset,
            template_mask=template_mask,
            target_mask=target_mask,
            rotations=self.rotations,
            invert_target=self._invert_target,
        )

        # Deal with splitting offsets
        mask = np.subtract(1, self._template_batch).astype(bool)
        target_offset = np.zeros(len(self._output_target_shape), dtype=int)
        target_offset[mask] = [x.start for x in target_slice]
        mask = np.subtract(1, self._target_batch).astype(bool)
        template_offset = np.zeros(len(self._output_template_shape), dtype=int)
        template_offset[mask] = [x.start for x in template_slice]
        ret._translation_offset = tuple(x for x in target_offset)

        ret.target_filter = self.target_filter
        ret.template_filter = self.template_filter

        ret.set_matching_dimension(
            target_dim=getattr(self, "_target_dim", None),
            template_dim=getattr(self, "_template_dim", None),
        )

        return ret

    def to_backend(self):
        """
        Transfer and convert types of internal data arrays to the current backend.

        Examples
        --------
        >>> matching_data.to_backend()
        """
        backend_arr = type(be.zeros((1), dtype=be._float_dtype))
        for attr_name, attr_value in vars(self).items():
            converted_array = None
            if isinstance(attr_value, np.ndarray):
                converted_array = be.to_backend_array(attr_value.copy())
            elif isinstance(attr_value, backend_arr):
                converted_array = be.to_backend_array(attr_value)
            else:
                continue

            current_dtype = be.get_fundamental_dtype(converted_array)
            target_dtype = be._fundamental_dtypes[current_dtype]

            # Optional, but scores are float so we avoid casting and potential issues
            if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
                target_dtype = be._float_dtype

            if target_dtype != current_dtype:
                converted_array = be.astype(converted_array, target_dtype)

            setattr(self, attr_name, converted_array)

    def set_matching_dimension(self, target_dim: int = None, template_dim: int = None):
        """
        Sets matching dimensions for target and template.

        Parameters
        ----------
        target_dim : int, optional
            Target batch dimension, None by default.
        template_dim : int, optional
            Template batch dimension, None by default.

        Examples
        --------
        >>> matching_data.set_matching_dimension(target_dim=0, template_dim=None)

        Notes
        -----
        If target and template share a batch dimension, the target will take
        precedence and the template dimension will be shifted to the right. If target
        and template have the same dimension, but target specifies batch dimensions,
        the leftmost template dimensions are assumed to be collapse dimensions.
        """
        target_ndim = len(self._target.shape)
        _, target_dims = self._compute_batch_dims(target_dim, ndim=target_ndim)
        template_ndim = len(self._template.shape)
        _, template_dims = self._compute_batch_dims(template_dim, ndim=template_ndim)

        target_ndim -= len(target_dims)
        template_ndim -= len(template_dims)

        if target_ndim != template_ndim:
            raise ValueError(
                f"Dimension mismatch: Target ({target_ndim}) Template ({template_ndim})."
            )
        self._set_matching_dimension(
            target_dims=target_dims, template_dims=template_dims
        )

    def _set_matching_dimension(
        self, target_dims: Tuple[int] = (), template_dims: Tuple[int] = ()
    ):
        self._target_dim, self._template_dim = target_dims, template_dims

        target_ndim, template_ndim = len(self._target.shape), len(self._template.shape)
        batch_dims = len(target_dims) + len(template_dims)
        target_measurement_dims = target_ndim - len(target_dims)
        collapse_dims = max(
            template_ndim - len(template_dims) - target_measurement_dims, 0
        )
        matching_dims = target_measurement_dims + batch_dims

        target_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
        template_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
        template_batch = np.full(shape=matching_dims, fill_value=1, dtype=int)
        target_batch = np.full(shape=matching_dims, fill_value=1, dtype=int)

        target_index, template_index = 0, 0
        for k in range(matching_dims):
            target_dim = k - target_index
            template_dim = k - template_index

            if target_dim in target_dims:
                target_shape[k] = self._target.shape[target_dim]
                template_batch[k] = 0
                if target_index == len(template_dims) and collapse_dims > 0:
                    template_shape[k] = self._template.shape[template_dim]
                    collapse_dims -= 1
                template_index += 1
                continue

            if template_dim in template_dims:
                template_shape[k] = self._template.shape[template_dim]
                target_batch[k] = 0
                target_index += 1
                continue

            target_batch[k] = template_batch[k] = 0
            if target_dim < target_ndim:
                target_shape[k] = self._target.shape[target_dim]
            if template_dim < template_ndim:
                template_shape[k] = self._template.shape[template_dim]

        batch_mask = np.logical_or(target_batch, template_batch)
        self._output_target_shape = tuple(int(x) for x in target_shape)
        self._output_template_shape = tuple(int(x) for x in template_shape)
        self._batch_mask = tuple(int(x) for x in batch_mask)
        self._template_batch = tuple(int(x) for x in template_batch)
        self._target_batch = tuple(int(x) for x in target_batch)

        output_shape = np.add(
            self._output_target_shape,
            np.multiply(self._template_batch, self._output_template_shape),
        )
        output_shape = np.subtract(output_shape, self._template_batch)
        self._output_shape = tuple(int(x) for x in output_shape)

    @staticmethod
    def _compute_batch_dims(batch_dims: Tuple[int], ndim: int) -> Tuple:
        """
        Computes a mask for the batch dimensions and the validated batch dimensions.

        Parameters
        ----------
        batch_dims : tuple of int
            A tuple of integers representing the batch dimensions.
        ndim : int
            The number of dimensions of the array.

        Returns
        -------
        Tuple[ArrayLike, tuple of int]
            Mask and the corresponding batch dimensions.

        Raises
        ------
        ValueError
            If any dimension in batch_dims is not less than ndim.
        """
        mask = np.zeros(ndim, dtype=int)
        if batch_dims is None:
            return mask, ()

        if isinstance(batch_dims, int):
            batch_dims = (batch_dims,)

        for dim in batch_dims:
            if dim < ndim:
                mask[dim] = 1
                continue
            raise ValueError(f"Batch indices needs to be < {ndim}, got {dim}.")

        return mask, batch_dims

    @staticmethod
    def _batch_shape(shape: Tuple[int], mask: Tuple[int], keepdims=True) -> Tuple[int]:
        if keepdims:
            return tuple(x if y == 0 else 1 for x, y in zip(shape, mask))
        return tuple(x for x, y in zip(shape, mask) if y == 0)

    @staticmethod
    def _batch_iter(shape: Tuple[int], mask: Tuple[int]) -> Generator:
        def _recursive_gen(current_shape, current_mask, current_slices):
            if not current_shape:
                yield current_slices
                return

            if current_mask[0] == 1:
                for i in range(current_shape[0]):
                    new_slices = current_slices + (slice(i, i + 1),)
                    yield from _recursive_gen(
                        current_shape[1:], current_mask[1:], new_slices
                    )
            else:
                new_slices = current_slices + (slice(None),)
                yield from _recursive_gen(
                    current_shape[1:], current_mask[1:], new_slices
                )

        return _recursive_gen(shape, mask, ())

    @staticmethod
    def _batch_axis(mask: Tuple[int]) -> Tuple[int]:
        return tuple(i for i in range(len(mask)) if mask[i] == 0)

    def target_padding(self, pad_target: bool = False) -> Tuple[int]:
        """
        Computes the padding of the target to the full convolution
        shape given the registered template.

        Parameters
        ----------
        pad_target : bool, optional
            Whether to pad the target, defaults to False.

        Returns
        -------
        tuple of int
            Padding along each dimension.

        Examples
        --------
        >>> matching_data.target_padding(pad_target=True)
        """
        padding = np.zeros(len(self._output_target_shape), dtype=int)
        if pad_target:
            padding = np.subtract(self._output_template_shape, 1)
            if hasattr(self, "_target_batch"):
                padding = np.multiply(padding, np.subtract(1, self._target_batch))

        if hasattr(self, "_template_batch"):
            padding = tuple(x for x, i in zip(padding, self._template_batch) if i == 0)

        return tuple(int(x) for x in padding)

    @staticmethod
    def _fourier_padding(
        target_shape: Tuple[int],
        template_shape: Tuple[int],
        pad_fourier: bool,
        batch_mask: Tuple[int] = None,
    ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
        fourier_pad = template_shape
        fourier_shift = np.zeros_like(template_shape)

        if batch_mask is None:
            batch_mask = np.zeros_like(template_shape)
        batch_mask = np.asarray(batch_mask)

        if not pad_fourier:
            fourier_pad = np.ones(len(fourier_pad), dtype=int)
        fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
        fourier_pad = np.add(fourier_pad, batch_mask)

        pad_shape = np.maximum(target_shape, template_shape)
        ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
        conv_shape, fast_shape, fast_ft_shape = ret

        template_mod = np.mod(template_shape, 2)
        if not pad_fourier:
            fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
            fourier_shift = np.subtract(fourier_shift, template_mod)

        shape_diff = np.multiply(
            np.subtract(target_shape, template_shape), 1 - batch_mask
        )
        shape_mask = shape_diff < 0
        if np.sum(shape_mask):
            shape_shift = np.divide(shape_diff, 2)
            offset = np.mod(shape_diff, 2)
            if pad_fourier:
                offset = -np.subtract(
                    offset,
                    np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
                )
            else:
                warnings.warn(
                    "Template is larger than target and padding is turned off. Consider "
                    "swapping them or activate padding. Correcting the shift for now."
                )
            shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
            fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)

        fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))

        return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift

    def fourier_padding(
        self, pad_fourier: bool = False
    ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
        """
        Computes efficient shapes for Fourier transforms and the associated shifts.

        Parameters
        ----------
        pad_fourier : bool, optional
            If true, returns the shape of the full-convolution defined as sum of target
            shape and template shape minus one, False by default.

        Returns
        -------
        Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
            Tuple with convolution, forward FT, inverse FT shape and corresponding shift.

        Examples
        --------
        >>> conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=True)
        """
        return self._fourier_padding(
            target_shape=be.to_numpy_array(self._output_target_shape),
            template_shape=be.to_numpy_array(self._output_template_shape),
            batch_mask=be.to_numpy_array(self._batch_mask),
            pad_fourier=pad_fourier,
        )

    def computation_schedule(
        self,
        matching_method: str = "FLCSphericalMask",
        max_cores: int = 1,
        use_gpu: bool = False,
        pad_fourier: bool = False,
        pad_target_edges: bool = False,
        analyzer_method: str = None,
        available_memory: int = None,
        max_splits: int = 256,
    ) -> Tuple[Dict, Tuple]:
        """
        Computes a parallelization schedule for a given template matching operation.

        Parameters
        ----------
        matching_method : str
            Matching method to use, default "FLCSphericalMask".
        max_cores : int, optional
            Maximum number of CPU cores to use, default 1.
        use_gpu : bool, optional
            Whether to utilize GPU acceleration, default False.
        pad_fourier : bool, optional
            Apply Fourier padding, default False.
        pad_target_edges : bool, optional
            Apply padding to target edges, default False.
        analyzer_method : str, optional
            Method used for score analysis, default None.
        available_memory : int, optional
            Available memory in bytes. If None, uses all available system memory.
        max_splits : int, optional
            Maximum number of splits to consider, default 256.

        Returns
        -------
        target_splits : dict
            Optimal splits for each axis of the target tensor.
        schedule : tuple
            (n_outer_jobs, n_inner_jobs_per_outer) defining the parallelization schedule.
        """

        if available_memory is None:
            available_memory = be.get_available_memory() * be.device_count()

        _template = self._output_template_shape
        shape1 = np.broadcast_shapes(
            self._output_target_shape,
            self._batch_shape(_template, np.subtract(1, self._template_batch)),
        )

        shape2 = tuple(0 for _ in _template)
        if pad_fourier:
            shape2 = np.multiply(_template, np.subtract(1, self._batch_mask))

        padding = tuple(0 for _ in self._output_target_shape)
        if pad_target_edges:
            padding = tuple(
                x if y == 0 else 1 for x, y in zip(_template, self._template_batch)
            )

        return compute_parallelization_schedule(
            shape1=shape1,
            shape2=shape2,
            shape1_padding=padding,
            max_cores=max_cores,
            max_ram=available_memory,
            matching_method=matching_method,
            analyzer_method=analyzer_method,
            backend=be._backend_name,
            float_nbytes=be.datatype_bytes(be._float_dtype),
            complex_nbytes=be.datatype_bytes(be._complex_dtype),
            integer_nbytes=be.datatype_bytes(be._int_dtype),
            split_only_outer=use_gpu,
            split_axes=self._target_dim if len(self._target_dim) else None,
            max_splits=max_splits,
        )

    @property
    def rotations(self):
        """Return stored rotation matrices."""
        return self._rotations

    @rotations.setter
    def rotations(self, rotations: BackendArray):
        """
        Set :py:attr:`MatchingData.rotations`.

        Parameters
        ----------
        rotations : BackendArray
            Rotation matrices with shape (d, d) or (n, d, d).
        """
        if rotations is None:
            print("No rotations provided, assuming identity for now.")
            rotations = np.eye(len(self._target.shape))

        if rotations.ndim not in (2, 3):
            raise ValueError("Rotations have to be a rank 2 or 3 array.")
        elif rotations.ndim == 2:
            print("Reshaping rotations array to rank 3.")
            rotations = rotations.reshape(1, *rotations.shape)
        self._rotations = rotations.astype(np.float32)

    @staticmethod
    def _get_data(
        attribute,
        output_shape: Tuple[int],
        reverse: bool = False,
        axis: Tuple[int] = None,
    ):
        if isinstance(attribute, Density):
            attribute = attribute.data

        if attribute is not None:
            if reverse:
                rev_axis = tuple(i for i in range(attribute.ndim) if i not in axis)
                attribute = be.reverse(attribute, axis=rev_axis)
            attribute = attribute.reshape(tuple(int(x) for x in output_shape))

        return attribute

    @property
    def target(self) -> BackendArray:
        """Return the target."""
        return self._get_data(self._target, self._output_target_shape, False)

    @property
    def target_mask(self) -> BackendArray:
        """Return the target mask."""
        target_mask = getattr(self, "_target_mask", None)
        if target_mask is None:
            return None

        _output_shape = self._output_target_shape
        if be.size(target_mask) != np.prod(_output_shape):
            _output_shape = self._batch_shape(_output_shape, self._target_batch, True)

        return self._get_data(target_mask, _output_shape, False)

    @property
    def template(self) -> BackendArray:
        """Return the reversed template."""
        _output_shape = self._output_template_shape
        return self._get_data(self._template, _output_shape, True, self._template_dim)

    @property
    def template_mask(self) -> BackendArray:
        """Return the reversed template mask."""
        template_mask = getattr(self, "_template_mask", None)
        if template_mask is None:
            return None

        _output_shape = self._output_template_shape
        if np.prod([int(i) for i in template_mask.shape]) != np.prod(_output_shape):
            _output_shape = self._batch_shape(_output_shape, self._template_batch, True)

        return self._get_data(template_mask, _output_shape, True, self._template_dim)

    @target.setter
    def target(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.target`.

        Parameters
        ----------
        arr : NDArray
            Array to set as the target.
        """
        self._target = arr

    @template.setter
    def template(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.template` and initialize
        :py:attr:`MatchingData.template_mask` to an uninformative
        mask filled with ones if not already defined.

        Parameters
        ----------
        arr : NDArray
            Array to set as the template.
        """
        self._template = arr
        if getattr(self, "_template_mask", None) is None:
            self._template_mask = np.full(
                shape=arr.shape, dtype=np.float32, fill_value=1
            )

    @staticmethod
    def _set_mask(mask, shape: Tuple[int]):
        if mask is not None:
            if np.broadcast_shapes(mask.shape, shape) != shape:
                raise ValueError("Mask and data shape need to be broadcastable.")
        return mask

    @target_mask.setter
    def target_mask(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.target_mask`.

        Parameters
        ----------
        arr : NDArray
            Array to set as the target_mask.
        """
        self._target_mask = self._set_mask(mask=arr, shape=self._target.shape)

    @template_mask.setter
    def template_mask(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.template_mask`.

        Parameters
        ----------
        arr : NDArray
            Array to set as the template_mask.
        """
        self._template_mask = self._set_mask(mask=arr, shape=self._template.shape)

    @staticmethod
    def _set_filter(composable_filter) -> Optional[Compose]:
        if composable_filter is None:
            return None

        if not isinstance(composable_filter, Compose):
            warnings.warn(
                "Custom filters are not sanitized and need to be correctly shaped."
            )

        return composable_filter

    @property
    def template_filter(self) -> Optional[Compose]:
        """
        Returns the template filter.

        Returns
        -------
        :py:class:`tme.preprocessing.compose.Compose` | BackendArray | None
            Composable filter, a backend array or None.
        """
        return getattr(self, "_template_filter", None)

    @property
    def target_filter(self) -> Optional[Compose]:
        """
        Returns the target filter.

        Returns
        -------
        :py:class:`tme.preprocessing.compose.Compose` | BackendArray | None
            Composable filter, a backend array or None.
        """
        return getattr(self, "_target_filter", None)

    @template_filter.setter
    def template_filter(self, template_filter):
        self._template_filter = self._set_filter(template_filter)

    @target_filter.setter
    def target_filter(self, target_filter):
        self._target_filter = self._set_filter(target_filter)

    def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
        """
        Split the rotation matrices into parts based on the number of jobs.

        Parameters
        ----------
        n_jobs : int
            Number of jobs for splitting.

        Returns
        -------
        list of NDArray
            List of split rotation matrices.
        """
        nrot_per_job = self.rotations.shape[0] // n_jobs
        rot_list = []
        for n in range(n_jobs):
            init_rot = n * nrot_per_job
            end_rot = init_rot + nrot_per_job
            if n == n_jobs - 1:
                end_rot = None
            rot_list.append(self.rotations[init_rot:end_rot])
        return rot_list

    def _free_data(self):
        """
        Dereference data arrays owned by the class instance.
        """
        attrs = ("_target", "_template", "_template_mask", "_target_mask")
        for attr in attrs:
            setattr(self, attr, None)