pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
- pytme-0.3.1.dist-info/RECORD +133 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
- pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +177 -226
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +69 -47
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +98 -28
- scripts/pytme_runner.py +1223 -0
- scripts/refine_matches.py +156 -387
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +4 -9
- tests/test_density.py +0 -1
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +11 -61
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +5 -8
- tme/analyzer/aggregation.py +28 -9
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +49 -122
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +31 -28
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +11 -54
- tme/backends/jax_backend.py +72 -48
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +95 -90
- tme/backends/pytorch_backend.py +5 -26
- tme/density.py +7 -10
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +138 -87
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +98 -112
- tme/filters/whitening.py +1 -6
- tme/mask.py +341 -0
- tme/matching_data.py +20 -44
- tme/matching_exhaustive.py +46 -56
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +216 -412
- tme/matching_utils.py +82 -424
- tme/memory.py +1 -1
- tme/orientations.py +16 -8
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- tme/rotations.py +1 -1
- pytme-0.3b0.dist-info/RECORD +0 -122
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/mask.py
ADDED
@@ -0,0 +1,341 @@
|
|
1
|
+
"""
|
2
|
+
Utility functions for generating template matching masks.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from typing import Tuple, Optional
|
11
|
+
|
12
|
+
from .types import NDArray
|
13
|
+
from scipy.ndimage import gaussian_filter
|
14
|
+
from .matching_utils import rigid_transform
|
15
|
+
|
16
|
+
__all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]
|
17
|
+
|
18
|
+
|
19
|
+
def elliptical_mask(
    shape: Tuple[int],
    radius: Tuple[float],
    center: Optional[Tuple[float]] = None,
    orientation: Optional[NDArray] = None,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates an ellipsoidal mask.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : tuple of floats
        Radius of the mask along each axis. Either a single value used for all
        axes or one value per axis of ``shape``.
    center : tuple of floats, optional
        Center of the mask. Either a single value used for all axes or one value
        per axis of ``shape``, defaults to shape // 2.
    orientation : NDArray, optional.
        Orientation of the mask as rotation matrix with shape (d,d).
    sigma_decay : float, optional
        Standard deviation of the Gaussian falloff outside the ellipsoid surface.
        It is normalized by the mean radius, so it is in voxels for spheres and
        approximately so for ellipsoids. Zero (default) yields a binary mask.
    cutoff_sigma : float, optional
        Falloff values below exp(-cutoff_sigma**2 / 2) are set to zero,
        defaults to 3.
    **kwargs
        Additional keyword arguments are ignored.

    Returns
    -------
    NDArray
        The created ellipsoidal mask. Integer 0/1 values if ``sigma_decay`` is
        zero, floating point values otherwise.

    Raises
    ------
    ValueError
        If the length of center and radius is not one or the same as shape.

    Examples
    --------
    >>> from tme.mask import elliptical_mask
    >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
    """
    shape = np.asarray(shape).astype(int)
    radius = np.asarray(radius)
    n = shape.size

    if center is None:
        center = np.divide(shape, 2).astype(int)
    center = np.asarray(center, dtype=np.float32)

    # Only a single value (broadcast to all axes) or one value per axis is valid.
    # Checking explicitly avoids silently expanding e.g. a length-2 radius to
    # (r0, r0, r1, r1) for a 4D shape, and avoids a ZeroDivisionError on empty input.
    if radius.size not in (1, n):
        raise ValueError("Length of radius has to be either one or match shape.")
    if center.size not in (1, n):
        raise ValueError("Length of center has to be either one or match shape.")
    radius = np.repeat(radius, n // radius.size)
    center = np.repeat(center, n // center.size)

    # Shape (n, 1, ..., 1) so both broadcast against np.indices(shape).
    center = center.reshape((-1,) + (1,) * n)
    radius = radius.reshape((-1,) + (1,) * n)

    indices = np.indices(shape, dtype=np.float32) - center
    if orientation is not None:
        return_shape = indices.shape
        indices = indices.reshape(n, -1)
        rigid_transform(
            coordinates=indices,
            rotation_matrix=np.asarray(orientation),
            out=indices,
            translation=np.zeros(n),
            use_geometric_center=False,
        )
        indices = indices.reshape(*return_shape)

    # Normalized distance: <= 1 inside the ellipsoid, > 1 outside.
    dist = np.linalg.norm(indices / radius, axis=0)
    if sigma_decay > 0:
        # Convert sigma to normalized-distance units (2 * sigma**2 denominator).
        sigma_decay = 2 * (sigma_decay / np.mean(radius)) ** 2
        mask = np.maximum(0, dist - 1)
        mask = np.exp(-(mask**2) / sigma_decay)
        mask *= mask > np.exp(-(cutoff_sigma**2) / 2)
    else:
        mask = (dist <= 1).astype(int)
    return mask
|
97
|
+
|
98
|
+
|
99
|
+
def box_mask(
    shape: Tuple[int],
    center: Tuple[int],
    size: Tuple[int],
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> np.ndarray:
    """
    Creates a box mask centered around the provided center point.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the output array.
    center : tuple of ints
        Center point coordinates of the box.
    size : tuple of ints
        Side length of the box along each axis.
    sigma_decay : float, optional
        Standard deviation of the Gaussian used to soften the box edges. Zero
        (default) returns a binary box.
    cutoff_sigma : float, optional
        Truncation of the Gaussian kernel in standard deviations. Softened values
        below exp(-cutoff_sigma**2 / 2) are set to zero, defaults to 3.
    **kwargs
        Additional keyword arguments are ignored.

    Returns
    -------
    NDArray
        The created box mask.

    Raises
    ------
    ValueError
        If ``shape``, ``center`` and ``size`` do not have the same length.
    """
    if len(shape) != len(center) or len(center) != len(size):
        raise ValueError("The length of shape, center, and size must be consistent.")

    shape = tuple(int(x) for x in shape)
    center, size = np.array(center, dtype=int), np.array(size, dtype=int)

    half_sizes = size // 2
    # Clip to the array bounds so boxes near the border are truncated.
    starts = np.maximum(center - half_sizes, 0)
    # NOTE(review): the stop index adds both size % 2 and 1, so the box spans
    # size + 1 voxels per axis (e.g. size 4 -> 5 voxels, size 5 -> 6 voxels).
    # Kept unchanged for backward compatibility; confirm this is intended.
    stops = np.minimum(center + half_sizes + np.mod(size, 2) + 1, shape)
    slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))

    out = np.zeros(shape)
    out[slice_indices] = 1

    if sigma_decay > 0:
        mask_filter = gaussian_filter(
            out.astype(np.float32), sigma=sigma_decay, truncate=cutoff_sigma
        )
        # Keep the box interior at exactly 1; add the blurred falloff outside only.
        out = np.add(out, (1 - out) * mask_filter)
        out *= out > np.exp(-(cutoff_sigma**2) / 2)
    return out
|
151
|
+
|
152
|
+
|
153
|
+
def tube_mask(
    shape: Tuple[int],
    symmetry_axis: int,
    center: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> NDArray:
    """
    Creates a tube mask.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    symmetry_axis : int
        The axis of symmetry for the tube. Negative values count from the last
        axis.
    center : tuple
        Center of the tube, one value per axis of ``shape``.
    inner_radius : float
        Inner radius of the tube.
    outer_radius : float
        Outer radius of the tube.
    height : int
        Height of the tube along ``symmetry_axis``.
    sigma_decay : float, optional
        Gaussian edge falloff forwarded to ``elliptical_mask`` for the inner and
        outer circles. Zero (default) yields binary edges.
    cutoff_sigma : float, optional
        Falloff cutoff forwarded to ``elliptical_mask``, defaults to 3.
    **kwargs
        Additional keyword arguments are ignored.

    Returns
    -------
    NDArray
        The created tube mask.

    Raises
    ------
    ValueError
        If ``inner_radius`` is larger than ``outer_radius``.
        If ``symmetry_axis`` is out of bounds for ``shape``.
        If ``height`` is larger than ``shape`` along the symmetry axis.
        If ``center`` and ``shape`` do not have the same length.
    """
    ndim = len(shape)
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")

    # Validate the axis before it is used to index shape below.
    if not -ndim <= symmetry_axis < ndim:
        raise ValueError(
            f"symmetry_axis must be between {-ndim} and {ndim - 1}."
        )
    # Normalize negative axes so the `ix != symmetry_axis` comparisons work.
    symmetry_axis = symmetry_axis % ndim

    if height > shape[symmetry_axis]:
        raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")

    if len(center) != ndim:
        raise ValueError("shape and center need to have the same length.")

    shape = tuple(int(x) for x in shape)
    circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
    circle_center = tuple(b for ix, b in enumerate(center) if ix != symmetry_axis)

    # Cross-section orthogonal to the symmetry axis: outer disk minus inner disk.
    inner_circle = np.zeros(circle_shape)
    outer_circle = np.zeros_like(inner_circle)
    if inner_radius > 0:
        inner_circle = elliptical_mask(
            shape=circle_shape,
            radius=inner_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    if outer_radius > 0:
        outer_circle = elliptical_mask(
            shape=circle_shape,
            radius=outer_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    circle = outer_circle - inner_circle
    circle = np.expand_dims(circle, axis=symmetry_axis)

    # Exactly `height` slices centered on the axial center, clipped to the array.
    axial_center = center[symmetry_axis]
    start_idx = int(axial_center - height // 2)
    stop_idx = int(axial_center + height // 2 + height % 2)

    start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])

    slice_indices = tuple(
        slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
        for i in range(len(shape))
    )
    tube = np.zeros(shape)
    # The singleton symmetry axis of `circle` broadcasts over the slab.
    tube[slice_indices] = circle

    return tube
|
245
|
+
|
246
|
+
|
247
|
+
def membrane_mask(
    shape: Tuple[int],
    radius: float,
    thickness: float,
    separation: float,
    symmetry_axis: int = 2,
    center: Optional[Tuple[float]] = None,
    sigma_decay: float = 0.5,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates a membrane mask consisting of two parallel disks with Gaussian intensity profile.
    Uses efficient broadcasting approach: flat disk mask × height profile.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : float
        Radius of the membrane disks.
    thickness : float
        Thickness of each disk in the membrane. Each leaflet has a Gaussian
        profile with standard deviation thickness / 3. A thickness of zero yields
        an all-zero mask.
    separation : float
        Distance between the centers of the two disks.
    symmetry_axis : int, optional
        The axis perpendicular to the membrane disks, defaults to 2. Negative
        values count from the last axis.
    center : tuple of floats, optional
        Center of the membrane (midpoint between the two disks). Either a single
        value for all axes or one value per axis. It defaults to shape / 2 along
        the symmetry axis, and to the ``elliptical_mask`` default (shape // 2)
        in-plane.
    sigma_decay : float, optional
        Gaussian falloff of the disk rim, forwarded to ``elliptical_mask``,
        defaults to 0.5.
    cutoff_sigma : float, optional
        Cutoff for height profile in standard deviations, defaults to 3.
    **kwargs
        Additional keyword arguments are ignored.

    Returns
    -------
    NDArray
        The created membrane mask with Gaussian intensity profile.

    Raises
    ------
    ValueError
        If ``thickness`` is negative.
        If ``separation`` is negative.
        If ``symmetry_axis`` is out of bounds.
        If ``center`` and ``shape`` do not have the same length.

    Examples
    --------
    >>> from tme.mask import membrane_mask
    >>> mask = membrane_mask(shape=(50,50,50), radius=10, thickness=2, separation=15)
    """
    shape = np.asarray(shape, dtype=int)
    ndim = shape.size

    # Remember whether the caller chose a center; if not, the in-plane disk keeps
    # elliptical_mask's own default placement.
    center_given = center is not None
    if center is None:
        center = np.divide(shape, 2).astype(float)
    center = np.asarray(center, dtype=np.float32)

    if thickness < 0:
        raise ValueError("thickness must be non-negative.")
    if separation < 0:
        raise ValueError("separation must be non-negative.")
    if symmetry_axis >= ndim:
        raise ValueError(f"symmetry_axis must be less than {len(shape)}.")
    if symmetry_axis < -ndim:
        raise ValueError(f"symmetry_axis must be at least {-ndim}.")
    # Normalize negative axes so the `i != symmetry_axis` comparisons work.
    symmetry_axis = symmetry_axis % ndim
    if center.size not in (1, ndim):
        raise ValueError("Length of center has to be either one or match shape.")
    center = np.repeat(center, ndim // center.size)

    in_plane = [i for i in range(ndim) if i != symmetry_axis]
    disk_mask = elliptical_mask(
        shape=[shape[i] for i in in_plane],
        radius=radius,
        # Honor a user-supplied center in-plane as well, not only axially.
        center=[center[i] for i in in_plane] if center_given else None,
        sigma_decay=sigma_decay,
        cutoff_sigma=cutoff_sigma,
    )

    axial_coord = np.arange(shape[symmetry_axis]) - center[symmetry_axis]
    height_profile = np.zeros((shape[symmetry_axis],), dtype=np.float32)
    # A zero thickness would divide by zero (NaN at the leaflet position); the
    # profile then stays all zeros instead.
    if thickness > 0:
        cutoff_threshold = np.exp(-(cutoff_sigma**2) / 2)
        for leaflet_pos in (-separation / 2, separation / 2):
            leaflet_profile = np.exp(
                -((axial_coord - leaflet_pos) ** 2) / (2 * (thickness / 3) ** 2)
            )
            leaflet_profile *= leaflet_profile > cutoff_threshold
            height_profile = np.maximum(height_profile, leaflet_profile)

    # Broadcast the in-plane disk against the axial profile.
    disk_mask = disk_mask.reshape(
        [1 if i == symmetry_axis else x for i, x in enumerate(shape)]
    )
    height_profile = height_profile.reshape(
        [x if i == symmetry_axis else 1 for i, x in enumerate(shape)]
    )

    return disk_mask * height_profile
|
tme/matching_data.py
CHANGED
@@ -128,11 +128,8 @@ class MatchingData:
|
|
128
128
|
slice_start = np.array([x.start for x in arr_slice], dtype=int)
|
129
129
|
slice_stop = np.array([x.stop for x in arr_slice], dtype=int)
|
130
130
|
|
131
|
-
|
132
|
-
|
133
|
-
# is defined from the perspective of the origin
|
134
|
-
right_pad = np.divide(padding, 2).astype(int)
|
135
|
-
left_pad = np.add(right_pad, np.mod(padding, 2))
|
131
|
+
left_pad = np.divide(padding, 2).astype(int)
|
132
|
+
right_pad = np.add(left_pad, np.mod(padding, 2))
|
136
133
|
|
137
134
|
data_voxels_left = np.minimum(slice_start, left_pad)
|
138
135
|
data_voxels_right = np.minimum(
|
@@ -175,7 +172,7 @@ class MatchingData:
|
|
175
172
|
target_pad: NDArray = None,
|
176
173
|
template_pad: NDArray = None,
|
177
174
|
invert_target: bool = False,
|
178
|
-
) -> "MatchingData":
|
175
|
+
) -> Tuple["MatchingData", Tuple]:
|
179
176
|
"""
|
180
177
|
Subset class instance based on slices.
|
181
178
|
|
@@ -194,6 +191,8 @@ class MatchingData:
|
|
194
191
|
-------
|
195
192
|
:py:class:`MatchingData`
|
196
193
|
Newly allocated subset of class instance.
|
194
|
+
Tuple
|
195
|
+
Translation offset to merge analyzers.
|
197
196
|
|
198
197
|
Examples
|
199
198
|
--------
|
@@ -252,7 +251,8 @@ class MatchingData:
|
|
252
251
|
mask = np.subtract(1, self._target_batch).astype(bool)
|
253
252
|
template_offset = np.zeros(len(self._output_template_shape), dtype=int)
|
254
253
|
template_offset[mask] = [x.start for x in template_slice]
|
255
|
-
|
254
|
+
|
255
|
+
translation_offset = tuple(x for x in target_offset)
|
256
256
|
|
257
257
|
ret.target_filter = self.target_filter
|
258
258
|
ret.template_filter = self.template_filter
|
@@ -262,7 +262,7 @@ class MatchingData:
|
|
262
262
|
template_dim=getattr(self, "_template_dim", None),
|
263
263
|
)
|
264
264
|
|
265
|
-
return ret
|
265
|
+
return ret, translation_offset
|
266
266
|
|
267
267
|
def to_backend(self):
|
268
268
|
"""
|
@@ -296,7 +296,7 @@ class MatchingData:
|
|
296
296
|
|
297
297
|
def set_matching_dimension(self, target_dim: int = None, template_dim: int = None):
|
298
298
|
"""
|
299
|
-
Sets matching dimensions for target and template.
|
299
|
+
Sets matching batch dimensions for target and template.
|
300
300
|
|
301
301
|
Parameters
|
302
302
|
----------
|
@@ -323,11 +323,6 @@ class MatchingData:
|
|
323
323
|
|
324
324
|
target_ndim -= len(target_dims)
|
325
325
|
template_ndim -= len(template_dims)
|
326
|
-
|
327
|
-
if target_ndim != template_ndim:
|
328
|
-
raise ValueError(
|
329
|
-
f"Dimension mismatch: Target ({target_ndim}) Template ({template_ndim})."
|
330
|
-
)
|
331
326
|
self._set_matching_dimension(
|
332
327
|
target_dims=target_dims, template_dims=template_dims
|
333
328
|
)
|
@@ -492,29 +487,26 @@ class MatchingData:
|
|
492
487
|
def _fourier_padding(
|
493
488
|
target_shape: Tuple[int],
|
494
489
|
template_shape: Tuple[int],
|
495
|
-
pad_fourier: bool,
|
496
490
|
batch_mask: Tuple[int] = None,
|
491
|
+
**kwargs,
|
497
492
|
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
498
|
-
fourier_pad = template_shape
|
499
|
-
fourier_shift = np.zeros_like(template_shape)
|
500
|
-
|
501
493
|
if batch_mask is None:
|
502
494
|
batch_mask = np.zeros_like(template_shape)
|
503
495
|
batch_mask = np.asarray(batch_mask)
|
504
496
|
|
505
|
-
|
506
|
-
fourier_pad = np.ones(len(fourier_pad), dtype=int)
|
497
|
+
fourier_pad = np.ones(len(template_shape), dtype=int)
|
507
498
|
fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
|
508
499
|
fourier_pad = np.add(fourier_pad, batch_mask)
|
509
500
|
|
501
|
+
# Avoid padding batch dimensions
|
510
502
|
pad_shape = np.maximum(target_shape, template_shape)
|
503
|
+
pad_shape = np.maximum(pad_shape, np.multiply(1 - batch_mask, pad_shape))
|
511
504
|
ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
|
512
505
|
conv_shape, fast_shape, fast_ft_shape = ret
|
513
506
|
|
514
507
|
template_mod = np.mod(template_shape, 2)
|
515
|
-
|
516
|
-
|
517
|
-
fourier_shift = np.subtract(fourier_shift, template_mod)
|
508
|
+
fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
|
509
|
+
fourier_shift = np.subtract(fourier_shift, template_mod)
|
518
510
|
|
519
511
|
shape_diff = np.multiply(
|
520
512
|
np.subtract(target_shape, template_shape), 1 - batch_mask
|
@@ -523,35 +515,20 @@ class MatchingData:
|
|
523
515
|
if np.sum(shape_mask):
|
524
516
|
shape_shift = np.divide(shape_diff, 2)
|
525
517
|
offset = np.mod(shape_diff, 2)
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
)
|
531
|
-
else:
|
532
|
-
warnings.warn(
|
533
|
-
"Template is larger than target and padding is turned off. Consider "
|
534
|
-
"swapping them or activate padding. Correcting the shift for now."
|
535
|
-
)
|
518
|
+
warnings.warn(
|
519
|
+
"Template is larger than target and padding is turned off. Consider "
|
520
|
+
"swapping them or activate padding. Correcting the shift for now."
|
521
|
+
)
|
536
522
|
shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
|
537
523
|
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
538
524
|
|
539
525
|
fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
|
540
|
-
|
541
526
|
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
542
527
|
|
543
|
-
def fourier_padding(
|
544
|
-
self, pad_fourier: bool = False
|
545
|
-
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
528
|
+
def fourier_padding(self, **kwargs) -> Tuple:
|
546
529
|
"""
|
547
530
|
Computes efficient shape four Fourier transforms and potential associated shifts.
|
548
531
|
|
549
|
-
Parameters
|
550
|
-
----------
|
551
|
-
pad_fourier : bool, optional
|
552
|
-
If true, returns the shape of the full-convolution defined as sum of target
|
553
|
-
shape and template shape minus one, False by default.
|
554
|
-
|
555
532
|
Returns
|
556
533
|
-------
|
557
534
|
Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
|
@@ -565,7 +542,6 @@ class MatchingData:
|
|
565
542
|
target_shape=be.to_numpy_array(self._output_target_shape),
|
566
543
|
template_shape=be.to_numpy_array(self._output_template_shape),
|
567
544
|
batch_mask=be.to_numpy_array(self._batch_mask),
|
568
|
-
pad_fourier=pad_fourier,
|
569
545
|
)
|
570
546
|
|
571
547
|
def computation_schedule(
|