pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/match_template.py +28 -39
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocessor_gui.py +95 -24
- pytme-0.3.1.post1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/RECORD +55 -48
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +28 -39
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +95 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +156 -386
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +26 -13
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +76 -37
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +2 -6
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +13 -8
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/top_level.txt +0 -0
tme/mask.py
ADDED
@@ -0,0 +1,341 @@
|
|
1
|
+
"""
|
2
|
+
Utility functions for generating template matching masks.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from typing import Tuple, Optional
|
11
|
+
|
12
|
+
from .types import NDArray
|
13
|
+
from scipy.ndimage import gaussian_filter
|
14
|
+
from .matching_utils import rigid_transform
|
15
|
+
|
16
|
+
__all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]
|
17
|
+
|
18
|
+
|
19
|
+
def elliptical_mask(
    shape: Tuple[int],
    radius: Tuple[float],
    center: Optional[Tuple[float]] = None,
    orientation: Optional[NDArray] = None,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates an ellipsoidal mask.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : float or tuple of floats
        Radius of the mask, either a single value used for all axes or one
        value per axis.
    center : tuple of floats, optional
        Center of the mask, defaults to shape // 2.
    orientation : NDArray, optional.
        Orientation of the mask as rotation matrix with shape (d,d).
    sigma_decay : float, optional
        Width of a Gaussian fall-off applied outside the ellipsoid surface.
        Distances are measured in radius-normalized units, so the width is
        divided by the mean radius before use. 0 (default) yields a binary mask.
    cutoff_sigma : float, optional
        Fall-off values beyond ``cutoff_sigma`` standard deviations are set to
        zero, defaults to 3. Only used if ``sigma_decay`` > 0.
    **kwargs
        Ignored, accepted so all mask functions share a calling convention.

    Returns
    -------
    NDArray
        The created ellipsoidal mask. An integer 0/1 array if ``sigma_decay``
        is 0, otherwise floating point values in [0, 1].

    Raises
    ------
    ValueError
        If the length of center and radius is not one or the same as shape.

    Examples
    --------
    >>> from tme.mask import elliptical_mask
    >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
    """
    shape = np.asarray(shape).astype(int)
    radius = np.asarray(radius)
    n = shape.size

    if center is None:
        center = np.divide(shape, 2).astype(int)
    center = np.asarray(center, dtype=np.float32)

    # Empty inputs would otherwise fail with ZeroDivisionError in the
    # repeat factors below rather than with the documented ValueError.
    if radius.size == 0:
        raise ValueError("Length of radius has to be either one or match shape.")
    if center.size == 0:
        raise ValueError("Length of center has to be either one or match shape.")

    # Broadcast scalar radius / center to one value per axis.
    radius = np.repeat(radius, n // radius.size)
    center = np.repeat(center, n // center.size)
    if radius.size != n:
        raise ValueError("Length of radius has to be either one or match shape.")
    if center.size != n:
        raise ValueError("Length of center has to be either one or match shape.")

    # Reshape to (n, 1, ..., 1) so they broadcast against np.indices output.
    center = center.reshape((-1,) + (1,) * n)
    radius = radius.reshape((-1,) + (1,) * n)

    indices = np.indices(shape, dtype=np.float32) - center
    if orientation is not None:
        return_shape = indices.shape
        indices = indices.reshape(n, -1)
        rigid_transform(
            coordinates=indices,
            rotation_matrix=np.asarray(orientation),
            out=indices,
            translation=np.zeros(n),
            use_geometric_center=False,
        )
        indices = indices.reshape(*return_shape)

    # Radius-normalized distance: values <= 1 lie inside the ellipsoid.
    dist = np.linalg.norm(indices / radius, axis=0)
    if sigma_decay > 0:
        # Gaussian variance term expressed in normalized distance units.
        scaled_variance = 2 * (sigma_decay / np.mean(radius)) ** 2
        mask = np.maximum(0, dist - 1)
        mask = np.exp(-(mask**2) / scaled_variance)
        mask *= mask > np.exp(-(cutoff_sigma**2) / 2)
    else:
        mask = (dist <= 1).astype(int)
    return mask
|
97
|
+
|
98
|
+
|
99
|
+
def box_mask(
    shape: Tuple[int],
    center: Tuple[int],
    size: Tuple[int],
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> np.ndarray:
    """
    Creates a box mask centered around the provided center point.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the output array.
    center : tuple of ints
        Center point coordinates of the box.
    size : tuple of ints
        Side length of the box along each axis. Odd side lengths are placed
        symmetrically around ``center``. Even side lengths span
        ``center - size // 2`` to ``center + size // 2 - 1`` (inclusive).
    sigma_decay : float, optional
        Standard deviation of the Gaussian used to soften the box edges.
        0 (default) produces a binary box.
    cutoff_sigma : float, optional
        Truncation of the Gaussian in standard deviations, defaults to 3.
    **kwargs
        Ignored, accepted so all mask functions share a calling convention.

    Returns
    -------
    NDArray
        The created box mask. Parts of the box outside ``shape`` are clipped.

    Raises
    ------
    ValueError
        If ``shape``, ``center`` and ``size`` do not have the same length.
    """
    if len(shape) != len(center) or len(center) != len(size):
        raise ValueError("The length of shape, center, and size must be consistent.")

    shape = tuple(int(x) for x in shape)
    center, size = np.array(center, dtype=int), np.array(size, dtype=int)

    # Exactly ``size`` voxels per axis, same convention as tube_mask's height.
    half_sizes = size // 2
    starts = np.maximum(center - half_sizes, 0)
    # Clip stops at 0 as well: a negative stop would wrap around in slicing
    # and mark voxels for boxes lying entirely outside the array.
    stops = np.clip(center + half_sizes + np.mod(size, 2), 0, shape)
    slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))

    out = np.zeros(shape)
    out[slice_indices] = 1

    if sigma_decay > 0:
        mask_filter = gaussian_filter(
            out.astype(np.float32), sigma=sigma_decay, truncate=cutoff_sigma
        )
        # Keep the box interior at 1 and only add the blurred fall-off outside.
        out = np.add(out, (1 - out) * mask_filter)
        out *= out > np.exp(-(cutoff_sigma**2) / 2)
    return out
|
151
|
+
|
152
|
+
|
153
|
+
def tube_mask(
    shape: Tuple[int],
    symmetry_axis: int,
    center: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> NDArray:
    """
    Creates a tube mask.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    symmetry_axis : int
        The axis of symmetry for the tube, in [0, len(shape)).
    center : tuple
        Center of the tube.
    inner_radius : float
        Inner radius of the tube.
    outer_radius : float
        Outer radius of the tube.
    height : int
        Height of the tube, i.e. its number of voxels along ``symmetry_axis``.
    sigma_decay : float, optional
        Edge fall-off forwarded to the elliptical masks forming the tube wall,
        0 (default) yields a binary tube.
    cutoff_sigma : float, optional
        Cutoff forwarded to the elliptical masks, defaults to 3.
    **kwargs
        Ignored, accepted so all mask functions share a calling convention.

    Returns
    -------
    NDArray
        The created tube mask.

    Raises
    ------
    ValueError
        If ``inner_radius`` is larger than ``outer_radius``.
        If ``symmetry_axis`` is out of bounds for ``shape``.
        If ``height`` is larger than the symmetry axis.
        If ``center`` and ``shape`` do not have the same length.
    """
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")

    # Validate the axis before it is used to index ``shape`` below.
    if not 0 <= symmetry_axis < len(shape):
        raise ValueError(f"symmetry_axis must be in [0, {len(shape)}).")

    if height > shape[symmetry_axis]:
        raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")

    if len(center) != len(shape):
        raise ValueError("shape and center need to have the same length.")

    shape = tuple(int(x) for x in shape)
    circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
    circle_center = tuple(b for ix, b in enumerate(center) if ix != symmetry_axis)

    # Tube cross-section: outer disk minus inner disk (annulus).
    inner_circle = np.zeros(circle_shape)
    outer_circle = np.zeros_like(inner_circle)
    if inner_radius > 0:
        inner_circle = elliptical_mask(
            shape=circle_shape,
            radius=inner_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    if outer_radius > 0:
        outer_circle = elliptical_mask(
            shape=circle_shape,
            radius=outer_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    circle = outer_circle - inner_circle
    circle = np.expand_dims(circle, axis=symmetry_axis)

    # Exactly ``height`` voxels along the symmetry axis, clipped to the array.
    axial_center = center[symmetry_axis]
    start_idx = int(axial_center - height // 2)
    stop_idx = int(axial_center + height // 2 + height % 2)

    # Clip stop at 0 too: a negative stop would wrap around when slicing.
    axis_length = shape[symmetry_axis]
    start_idx = max(start_idx, 0)
    stop_idx = min(max(stop_idx, 0), axis_length)

    slice_indices = tuple(
        slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
        for i in range(len(shape))
    )
    tube = np.zeros(shape)
    tube[slice_indices] = circle

    return tube
|
245
|
+
|
246
|
+
|
247
|
+
def membrane_mask(
    shape: Tuple[int],
    radius: float,
    thickness: float,
    separation: float,
    symmetry_axis: int = 2,
    center: Optional[Tuple[float]] = None,
    sigma_decay: float = 0.5,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates a membrane mask consisting of two parallel disks with Gaussian intensity profile.
    Uses efficient broadcasting approach: flat disk mask × height profile.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : float
        Radius of the membrane disks.
    thickness : float
        Thickness of each disk in the membrane. Each disk has a Gaussian
        profile with standard deviation ``thickness / 3`` along
        ``symmetry_axis``. A thickness of 0 only marks voxels lying exactly
        on a disk position.
    separation : float
        Distance between the centers of the two disks.
    symmetry_axis : int, optional
        The axis perpendicular to the membrane disks, defaults to 2.
    center : tuple of floats, optional
        Center of the membrane, defaults to shape / 2. The components
        orthogonal to ``symmetry_axis`` position the disks in-plane. The
        component along ``symmetry_axis`` is the midpoint between the two disks.
    sigma_decay : float, optional
        Controls edge sharpness relative to radius, defaults to 0.5.
    cutoff_sigma : float, optional
        Cutoff for height profile in standard deviations, defaults to 3.
    **kwargs
        Ignored, accepted so all mask functions share a calling convention.

    Returns
    -------
    NDArray
        The created membrane mask with Gaussian intensity profile.

    Raises
    ------
    ValueError
        If ``thickness`` is negative.
        If ``separation`` is negative.
        If ``center`` and ``shape`` do not have the same length.
        If ``symmetry_axis`` is out of bounds.

    Examples
    --------
    >>> from tme.mask import membrane_mask
    >>> mask = membrane_mask(shape=(50,50,50), radius=10, thickness=2, separation=15)
    """
    shape = np.asarray(shape, dtype=int)

    # Only a caller-provided center is forwarded to the in-plane disk. The
    # default keeps elliptical_mask's own default, so default calls are
    # unaffected.
    center_given = center is not None
    if center is None:
        center = np.divide(shape, 2).astype(float)

    center = np.asarray(center, dtype=np.float32)
    if center.size == 0:
        raise ValueError("Length of center has to be either one or match shape.")
    center = np.repeat(center, shape.size // center.size)

    if thickness < 0:
        raise ValueError("thickness must be non-negative.")
    if separation < 0:
        raise ValueError("separation must be non-negative.")
    if not 0 <= symmetry_axis < len(shape):
        raise ValueError(f"symmetry_axis must be in [0, {len(shape)}).")
    if center.size != shape.size:
        raise ValueError("Length of center has to be either one or match shape.")

    in_plane_axes = [i for i in range(shape.size) if i != symmetry_axis]
    disk_mask = elliptical_mask(
        shape=[shape[i] for i in in_plane_axes],
        radius=radius,
        center=[center[i] for i in in_plane_axes] if center_given else None,
        sigma_decay=sigma_decay,
        cutoff_sigma=cutoff_sigma,
    )

    axial_coord = np.arange(shape[symmetry_axis]) - center[symmetry_axis]
    cutoff_threshold = np.exp(-(cutoff_sigma**2) / 2)
    leaflet_sigma = thickness / 3

    height_profile = np.zeros((shape[symmetry_axis],), dtype=np.float32)
    for leaflet_pos in [-separation / 2, separation / 2]:
        offset = axial_coord - leaflet_pos
        if leaflet_sigma > 0:
            leaflet_profile = np.exp(-(offset**2) / (2 * leaflet_sigma**2))
        else:
            # Zero-width limit of the Gaussian; avoids 0 / 0 -> NaN.
            leaflet_profile = (offset == 0).astype(np.float32)
        leaflet_profile *= leaflet_profile > cutoff_threshold

        height_profile = np.maximum(height_profile, leaflet_profile)

    # Broadcast the in-plane disk against the axial profile.
    disk_mask = disk_mask.reshape(
        [x if i != symmetry_axis else 1 for i, x in enumerate(shape)]
    )
    height_profile = height_profile.reshape(
        [1 if i != symmetry_axis else x for i, x in enumerate(shape)]
    )

    return disk_mask * height_profile
|
tme/matching_data.py
CHANGED
@@ -128,11 +128,8 @@ class MatchingData:
|
|
128
128
|
slice_start = np.array([x.start for x in arr_slice], dtype=int)
|
129
129
|
slice_stop = np.array([x.stop for x in arr_slice], dtype=int)
|
130
130
|
|
131
|
-
|
132
|
-
|
133
|
-
# is defined from the perspective of the origin
|
134
|
-
right_pad = np.divide(padding, 2).astype(int)
|
135
|
-
left_pad = np.add(right_pad, np.mod(padding, 2))
|
131
|
+
left_pad = np.divide(padding, 2).astype(int)
|
132
|
+
right_pad = np.add(left_pad, np.mod(padding, 2))
|
136
133
|
|
137
134
|
data_voxels_left = np.minimum(slice_start, left_pad)
|
138
135
|
data_voxels_right = np.minimum(
|
@@ -253,7 +250,7 @@ class MatchingData:
|
|
253
250
|
target_offset[mask] = [x.start for x in target_slice]
|
254
251
|
mask = np.subtract(1, self._target_batch).astype(bool)
|
255
252
|
template_offset = np.zeros(len(self._output_template_shape), dtype=int)
|
256
|
-
template_offset[mask] = [x.start for x
|
253
|
+
template_offset[mask] = [x.start for x in template_slice]
|
257
254
|
|
258
255
|
translation_offset = tuple(x for x in target_offset)
|
259
256
|
|
@@ -299,7 +296,7 @@ class MatchingData:
|
|
299
296
|
|
300
297
|
def set_matching_dimension(self, target_dim: int = None, template_dim: int = None):
|
301
298
|
"""
|
302
|
-
Sets matching dimensions for target and template.
|
299
|
+
Sets matching batch dimensions for target and template.
|
303
300
|
|
304
301
|
Parameters
|
305
302
|
----------
|
@@ -490,8 +487,8 @@ class MatchingData:
|
|
490
487
|
def _fourier_padding(
|
491
488
|
target_shape: Tuple[int],
|
492
489
|
template_shape: Tuple[int],
|
493
|
-
pad_target: bool = False,
|
494
490
|
batch_mask: Tuple[int] = None,
|
491
|
+
**kwargs,
|
495
492
|
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
496
493
|
if batch_mask is None:
|
497
494
|
batch_mask = np.zeros_like(template_shape)
|
@@ -503,7 +500,7 @@ class MatchingData:
|
|
503
500
|
|
504
501
|
# Avoid padding batch dimensions
|
505
502
|
pad_shape = np.maximum(target_shape, template_shape)
|
506
|
-
pad_shape = np.maximum(
|
503
|
+
pad_shape = np.maximum(pad_shape, np.multiply(1 - batch_mask, pad_shape))
|
507
504
|
ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
|
508
505
|
conv_shape, fast_shape, fast_ft_shape = ret
|
509
506
|
|
@@ -525,21 +522,13 @@ class MatchingData:
|
|
525
522
|
shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
|
526
523
|
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
527
524
|
|
528
|
-
if pad_target:
|
529
|
-
fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
|
530
|
-
|
531
525
|
fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
|
532
526
|
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
533
527
|
|
534
|
-
def fourier_padding(self,
|
528
|
+
def fourier_padding(self, **kwargs) -> Tuple:
|
535
529
|
"""
|
536
530
|
Computes efficient shape four Fourier transforms and potential associated shifts.
|
537
531
|
|
538
|
-
Parameters
|
539
|
-
----------
|
540
|
-
pad_target : bool, optional
|
541
|
-
Whether the target has been padded to the full convolution shape.
|
542
|
-
|
543
532
|
Returns
|
544
533
|
-------
|
545
534
|
Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
|
@@ -553,9 +542,32 @@ class MatchingData:
|
|
553
542
|
target_shape=be.to_numpy_array(self._output_target_shape),
|
554
543
|
template_shape=be.to_numpy_array(self._output_template_shape),
|
555
544
|
batch_mask=be.to_numpy_array(self._batch_mask),
|
556
|
-
pad_target=pad_target,
|
557
545
|
)
|
558
546
|
|
547
|
+
def _score_mask(self, fast_shape: Tuple[int], shift: Tuple[int]) -> BackendArray:
    """
    Create a boolean mask to exclude scores derived from padding in template matching.

    Along every non-batch axis, only the region
    ``[pad // 2, pad // 2 + (target_extent - pad))`` is marked valid. Batch
    axes are kept in full. The mask is then circularly shifted by ``-shift``
    along each axis and converted to a backend array.
    """
    padding = self.target_padding(True)
    offsets = tuple(pad // 2 for pad in padding)
    extents = tuple(n - pad for pad, n in zip(padding, self.target.shape))

    valid_region = []
    for axis, start in enumerate(offsets):
        valid_region.append(
            slice(None)
            if self._batch_mask[axis]
            else slice(start, start + extents[axis])
        )

    mask = np.zeros(fast_shape, dtype=bool)
    mask[tuple(valid_region)] = 1

    # Circularly shift by the negated ``shift`` along every axis.
    negated_shift = tuple(-step for step in shift)
    mask = np.roll(mask, shift=negated_shift, axis=tuple(range(len(shift))))
    return be.to_backend_array(mask)
|
570
|
+
|
559
571
|
def computation_schedule(
|
560
572
|
self,
|
561
573
|
matching_method: str = "FLCSphericalMask",
|
tme/matching_exhaustive.py
CHANGED
@@ -40,7 +40,7 @@ def _setup_template_filter_apply_target_filter(
|
|
40
40
|
matching_data: MatchingData,
|
41
41
|
fast_shape: Tuple[int],
|
42
42
|
fast_ft_shape: Tuple[int],
|
43
|
-
pad_template_filter: bool =
|
43
|
+
pad_template_filter: bool = False,
|
44
44
|
):
|
45
45
|
target_filter = None
|
46
46
|
backend_arr = type(be.zeros((1), dtype=be._float_dtype))
|
@@ -146,11 +146,10 @@ def scan(
|
|
146
146
|
matching_data: MatchingData,
|
147
147
|
matching_setup: Callable,
|
148
148
|
matching_score: Callable,
|
149
|
-
|
150
|
-
callback_class: CallbackClass = None,
|
149
|
+
callback_class: CallbackClass,
|
151
150
|
callback_class_args: Dict = {},
|
151
|
+
n_jobs: int = 4,
|
152
152
|
pad_target: bool = True,
|
153
|
-
pad_template_filter: bool = True,
|
154
153
|
interpolation_order: int = 3,
|
155
154
|
jobs_per_callback_class: int = 8,
|
156
155
|
shm_handler=None,
|
@@ -172,20 +171,22 @@ def scan(
|
|
172
171
|
Function pointer to scoring function.
|
173
172
|
n_jobs : int, optional
|
174
173
|
Number of parallel jobs. Default is 4.
|
175
|
-
callback_class : type
|
174
|
+
callback_class : type
|
176
175
|
Analyzer class pointer to operate on computed scores.
|
177
176
|
callback_class_args : dict, optional
|
178
177
|
Arguments passed to the callback_class. Default is an empty dictionary.
|
179
178
|
pad_target: bool, optional
|
180
179
|
Whether to pad target to the full convolution shape.
|
181
|
-
pad_template_filter: bool, optional
|
182
|
-
Whether to pad potential template filters to the full convolution shape.
|
183
180
|
interpolation_order : int, optional
|
184
181
|
Order of spline interpolation for rotations.
|
185
182
|
jobs_per_callback_class : int, optional
|
186
183
|
Number of jobs a callback_class instance is shared between, 8 by default.
|
187
184
|
shm_handler : type, optional
|
188
185
|
Manager for shared memory objects, None by default.
|
186
|
+
target_slice : tuple of slice, optional
|
187
|
+
Target subset to process.
|
188
|
+
template_slice : tuple of slice, optional
|
189
|
+
Template subset to process.
|
189
190
|
|
190
191
|
Returns
|
191
192
|
-------
|
@@ -220,13 +221,17 @@ def scan(
|
|
220
221
|
template_shape = matching_data._batch_shape(
|
221
222
|
matching_data.template.shape, matching_data._template_batch
|
222
223
|
)
|
223
|
-
conv, fwd, inv, shift = matching_data.fourier_padding(
|
224
|
+
conv, fwd, inv, shift = matching_data.fourier_padding()
|
225
|
+
|
226
|
+
score_mask = be.full(shape=(1,), fill_value=1, dtype=bool)
|
227
|
+
if pad_target:
|
228
|
+
score_mask = matching_data._score_mask(fwd, shift)
|
224
229
|
|
225
230
|
template_filter = _setup_template_filter_apply_target_filter(
|
226
231
|
matching_data=matching_data,
|
227
232
|
fast_shape=fwd,
|
228
233
|
fast_ft_shape=inv,
|
229
|
-
pad_template_filter=
|
234
|
+
pad_template_filter=False,
|
230
235
|
)
|
231
236
|
|
232
237
|
default_callback_args = {
|
@@ -240,11 +245,10 @@ def scan(
|
|
240
245
|
"thread_safe": n_jobs > 1,
|
241
246
|
"convolution_mode": "valid" if pad_target else "same",
|
242
247
|
"shm_handler": shm_handler,
|
243
|
-
"only_unique_rotations": True,
|
244
248
|
"aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
|
245
249
|
"n_rotations": matching_data.rotations.shape[0],
|
250
|
+
"inversion_mapping": n_jobs == 1,
|
246
251
|
}
|
247
|
-
callback_class_args["inversion_mapping"] = n_jobs == 1
|
248
252
|
default_callback_args.update(callback_class_args)
|
249
253
|
|
250
254
|
setup = matching_setup(
|
@@ -254,22 +258,14 @@ def scan(
|
|
254
258
|
fast_ft_shape=inv,
|
255
259
|
shm_handler=shm_handler,
|
256
260
|
)
|
257
|
-
setup["interpolation_order"] = interpolation_order
|
258
|
-
setup["template_filter"] = be.to_sharedarr(template_filter, shm_handler)
|
259
261
|
|
260
262
|
matching_data._free_data()
|
261
|
-
be.free_cache()
|
262
|
-
|
263
263
|
n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
|
264
264
|
callback_classes = [
|
265
|
-
(
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
shm_handler=shm_handler if n_jobs > 1 else None,
|
270
|
-
)
|
271
|
-
if callback_class
|
272
|
-
else None
|
265
|
+
SharedAnalyzerProxy(
|
266
|
+
callback_class,
|
267
|
+
default_callback_args,
|
268
|
+
shm_handler=shm_handler if n_jobs > 1 else None,
|
273
269
|
)
|
274
270
|
for _ in range(n_callback_classes)
|
275
271
|
]
|
@@ -277,35 +273,33 @@ def scan(
|
|
277
273
|
delayed(_wrap_backend(matching_score))(
|
278
274
|
backend_name=be._backend_name,
|
279
275
|
backend_args=be._backend_args,
|
276
|
+
fast_shape=fwd,
|
277
|
+
fast_ft_shape=inv,
|
280
278
|
rotations=rotation,
|
281
279
|
callback=callback_classes[index % n_callback_classes],
|
280
|
+
interpolation_order=interpolation_order,
|
281
|
+
template_filter=be.to_sharedarr(template_filter, shm_handler),
|
282
|
+
score_mask=be.to_sharedarr(score_mask, shm_handler),
|
282
283
|
**setup,
|
283
284
|
)
|
284
285
|
for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
|
285
286
|
)
|
286
|
-
callbacks = [
|
287
|
-
callback.result(**default_callback_args)
|
288
|
-
for callback in ret[:n_callback_classes]
|
289
|
-
if callback
|
290
|
-
]
|
291
287
|
be.free_cache()
|
292
288
|
|
293
|
-
|
294
|
-
|
295
|
-
return ret
|
289
|
+
callbacks = [x.result(**default_callback_args) for x in ret[:n_callback_classes]]
|
290
|
+
return callback_class.merge(callbacks, **default_callback_args)
|
296
291
|
|
297
292
|
|
298
293
|
def scan_subsets(
|
299
294
|
matching_data: MatchingData,
|
300
295
|
matching_score: Callable,
|
301
296
|
matching_setup: Callable,
|
302
|
-
callback_class: CallbackClass
|
297
|
+
callback_class: CallbackClass,
|
303
298
|
callback_class_args: Dict = {},
|
304
299
|
job_schedule: Tuple[int] = (1, 1),
|
305
300
|
target_splits: Dict = {},
|
306
301
|
template_splits: Dict = {},
|
307
302
|
pad_target_edges: bool = False,
|
308
|
-
pad_template_filter: bool = True,
|
309
303
|
interpolation_order: int = 3,
|
310
304
|
jobs_per_callback_class: int = 8,
|
311
305
|
backend_name: str = None,
|
@@ -325,7 +319,7 @@ def scan_subsets(
|
|
325
319
|
Function pointer to setup function.
|
326
320
|
matching_score : type
|
327
321
|
Function pointer to scoring function.
|
328
|
-
callback_class : type
|
322
|
+
callback_class : type
|
329
323
|
Analyzer class pointer to operate on computed scores.
|
330
324
|
callback_class_args : dict, optional
|
331
325
|
Arguments passed to the callback_class. Default is an empty dictionary.
|
@@ -341,8 +335,6 @@ def scan_subsets(
|
|
341
335
|
See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
|
342
336
|
pad_target_edges : bool, optional
|
343
337
|
Pad the target boundaries to avoid edge effects.
|
344
|
-
pad_template_filter: bool, optional
|
345
|
-
Whether to pad potential template filters to the full convolution shape.
|
346
338
|
interpolation_order : int, optional
|
347
339
|
Order of spline interpolation for rotations.
|
348
340
|
jobs_per_callback_class : int, optional
|
@@ -424,18 +416,22 @@ def scan_subsets(
|
|
424
416
|
)
|
425
417
|
splits = tuple(product(target_splits, template_splits))
|
426
418
|
|
419
|
+
kwargs = {
|
420
|
+
"matching_data": matching_data,
|
421
|
+
"callback_class": callback_class,
|
422
|
+
"callback_class_args": callback_class_args,
|
423
|
+
}
|
424
|
+
|
427
425
|
outer_jobs, inner_jobs = job_schedule
|
428
426
|
if be._backend_name == "jax":
|
429
427
|
func = be.scan
|
430
428
|
|
431
429
|
corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
|
432
430
|
results = func(
|
433
|
-
matching_data=matching_data,
|
434
431
|
splits=splits,
|
435
432
|
n_jobs=outer_jobs,
|
436
433
|
rotate_mask=matching_score != corr_scoring,
|
437
|
-
|
438
|
-
callback_class_args=callback_class_args,
|
434
|
+
**kwargs,
|
439
435
|
)
|
440
436
|
else:
|
441
437
|
results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
|
@@ -443,26 +439,20 @@ def scan_subsets(
|
|
443
439
|
delayed(_wrap_backend(scan))(
|
444
440
|
backend_name=be._backend_name,
|
445
441
|
backend_args=be._backend_args,
|
446
|
-
matching_data=matching_data,
|
447
442
|
matching_score=matching_score,
|
448
443
|
matching_setup=matching_setup,
|
449
444
|
n_jobs=inner_jobs,
|
450
|
-
callback_class=callback_class,
|
451
|
-
callback_class_args=callback_class_args,
|
452
445
|
interpolation_order=interpolation_order,
|
453
446
|
pad_target=pad_target_edges,
|
454
447
|
gpu_index=index % outer_jobs,
|
455
|
-
pad_template_filter=pad_template_filter,
|
456
448
|
target_slice=target_split,
|
457
449
|
template_slice=template_split,
|
450
|
+
**kwargs,
|
458
451
|
)
|
459
452
|
for index, (target_split, template_split) in enumerate(splits)
|
460
453
|
]
|
461
454
|
)
|
462
|
-
|
463
|
-
if callback_class is not None:
|
464
|
-
return callback_class.merge(results, **callback_class_args)
|
465
|
-
return None
|
455
|
+
return callback_class.merge(results, **callback_class_args)
|
466
456
|
|
467
457
|
|
468
458
|
def register_matching_exhaustive(
|
tme/matching_optimization.py
CHANGED
@@ -1104,7 +1104,8 @@ def create_score_object(score: str, **kwargs) -> object:
|
|
1104
1104
|
Examples
|
1105
1105
|
--------
|
1106
1106
|
>>> from tme import Density
|
1107
|
-
>>> from tme.
|
1107
|
+
>>> from tme.mask import create_mask
|
1108
|
+
>>> from tme.matching_utils import euler_to_rotationmatrix
|
1108
1109
|
>>> from tme.matching_optimization import CrossCorrelation, optimize_match
|
1109
1110
|
>>> translation, rotation = (5, -2, 7), (5, -10, 2)
|
1110
1111
|
>>> target = create_mask(
|