pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +28 -39
- scripts/postprocess.py +23 -10
- scripts/preprocessor_gui.py +95 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +156 -386
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +13 -3
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +15 -13
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +5 -44
- tme/backends/jax_backend.py +58 -37
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +68 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +2 -6
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +7 -19
- tme/matching_exhaustive.py +34 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +206 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +4 -6
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/filters/ctf.py
CHANGED
@@ -36,7 +36,7 @@ class CTF(ComposableFilter):
|
|
36
36
|
|
37
37
|
#: The shape of the to-be created mask.
|
38
38
|
shape: Tuple[int] = None
|
39
|
-
#: The defocus
|
39
|
+
#: The defocus in x direction (in units of sampling rate).
|
40
40
|
defocus_x: Tuple[float] = None
|
41
41
|
#: The tilt angles.
|
42
42
|
angles: Tuple[float] = None
|
@@ -164,7 +164,7 @@ class CTF(ComposableFilter):
|
|
164
164
|
shape : tuple of int
|
165
165
|
The shape of the CTF.
|
166
166
|
defocus_x : tuple of float
|
167
|
-
The defocus
|
167
|
+
The defocus in x direction (in units of sampling rate).
|
168
168
|
angles : tuple of float
|
169
169
|
The tilt angles.
|
170
170
|
opening_axis : int, optional
|
@@ -178,7 +178,7 @@ class CTF(ComposableFilter):
|
|
178
178
|
defocus_angle : tuple of float, optional
|
179
179
|
The defocus angle in radians, defaults to 0.
|
180
180
|
defocus_y : tuple of float, optional
|
181
|
-
The defocus
|
181
|
+
The defocus in x direction (in units of sampling rate).
|
182
182
|
correct_defocus_gradient : bool, optional
|
183
183
|
Whether to correct defocus gradient, defaults to False.
|
184
184
|
sampling_rate : tuple of float, optional
|
@@ -219,14 +219,12 @@ class CTF(ComposableFilter):
|
|
219
219
|
corrected_tilt_axis -= 1
|
220
220
|
|
221
221
|
for index, angle in enumerate(angles):
|
222
|
-
defocus_x, defocus_y = defoci_x[index], defoci_y[index]
|
223
|
-
|
224
222
|
correction = correct_defocus_gradient and angle is not None
|
225
223
|
chi = create_ctf(
|
226
224
|
angle=angle,
|
227
225
|
shape=ctf_shape,
|
228
|
-
defocus_x=
|
229
|
-
defocus_y=
|
226
|
+
defocus_x=defoci_x[index],
|
227
|
+
defocus_y=defoci_y[index],
|
230
228
|
sampling_rate=sampling_rate,
|
231
229
|
acceleration_voltage=acceleration_voltage[index],
|
232
230
|
correct_defocus_gradient=correction,
|
@@ -243,12 +241,10 @@ class CTF(ComposableFilter):
|
|
243
241
|
stack[index] = chi
|
244
242
|
|
245
243
|
# Avoid contrast inversion
|
246
|
-
np.negative(stack, out=stack)
|
244
|
+
stack = np.negative(stack, out=stack)
|
247
245
|
if flip_phase:
|
248
|
-
np.abs(stack, out=stack)
|
249
|
-
|
250
|
-
stack = be.to_backend_array(np.squeeze(stack))
|
251
|
-
return stack
|
246
|
+
stack = np.abs(stack, out=stack)
|
247
|
+
return be.to_backend_array(np.squeeze(stack))
|
252
248
|
|
253
249
|
|
254
250
|
class CTFReconstructed(CTF):
|
@@ -281,7 +277,7 @@ class CTFReconstructed(CTF):
|
|
281
277
|
shape : tuple of int
|
282
278
|
The shape of the CTF.
|
283
279
|
defocus_x : tuple of float
|
284
|
-
The defocus
|
280
|
+
The defocus in x direction in units of sampling rate.
|
285
281
|
opening_axis : int, optional
|
286
282
|
The axis around which the wedge is opened, defaults to 2.
|
287
283
|
amplitude_contrast : float, optional
|
@@ -291,7 +287,7 @@ class CTFReconstructed(CTF):
|
|
291
287
|
defocus_angle : tuple of float, optional
|
292
288
|
The defocus angle in radians, defaults to 0.
|
293
289
|
defocus_y : tuple of float, optional
|
294
|
-
The defocus
|
290
|
+
The defocus in y direction in units of sampling rate.
|
295
291
|
sampling_rate : tuple of float, optional
|
296
292
|
The sampling rate, defaults to 1.
|
297
293
|
acceleration_voltage : float, optional
|
@@ -321,18 +317,15 @@ class CTFReconstructed(CTF):
|
|
321
317
|
defocus_angle=defocus_angle,
|
322
318
|
amplitude_contrast=amplitude_contrast,
|
323
319
|
)
|
324
|
-
stack = shift_fourier(data=stack, shape_is_real_fourier=False)
|
325
|
-
|
326
320
|
# Avoid contrast inversion
|
327
321
|
np.negative(stack, out=stack)
|
328
322
|
if flip_phase:
|
329
323
|
np.abs(stack, out=stack)
|
330
324
|
|
331
|
-
stack =
|
325
|
+
stack = shift_fourier(data=stack, shape_is_real_fourier=False)
|
332
326
|
if return_real_fourier:
|
333
327
|
stack = crop_real_fourier(stack)
|
334
|
-
|
335
|
-
return stack
|
328
|
+
return be.to_backend_array(np.squeeze(stack))
|
336
329
|
|
337
330
|
|
338
331
|
def _from_xml(filename: str) -> Dict:
|
@@ -436,6 +429,9 @@ def _from_ctffind(filename: str) -> Dict:
|
|
436
429
|
output[key] = np.array(output[key])
|
437
430
|
|
438
431
|
output["additional_phase_shift"] = np.degrees(output["additional_phase_shift"])
|
432
|
+
cs = output.get("spherical_aberration")
|
433
|
+
if cs is not None:
|
434
|
+
output["spherical_aberration"] = float(cs) * 1e7
|
439
435
|
return output
|
440
436
|
|
441
437
|
|
@@ -566,7 +562,7 @@ def create_ctf(
|
|
566
562
|
amplitude_contrast : float, optional
|
567
563
|
Amplitude contrast of microscope, defaults to 0.07.
|
568
564
|
spherical_aberration : float, optional
|
569
|
-
Spherical aberration of microscope in
|
565
|
+
Spherical aberration of microscope in units of sampling rate.
|
570
566
|
angle : float, optional
|
571
567
|
Assume the created CTF is a projection over opening_axis observed at angle.
|
572
568
|
opening_axis : int, optional
|
@@ -590,10 +586,14 @@ def create_ctf(
|
|
590
586
|
electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
|
591
587
|
electron_wavelength /= sampling_rate
|
592
588
|
aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2
|
589
|
+
|
590
|
+
defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
|
591
|
+
defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
|
593
592
|
if correct_defocus_gradient or defocus_y is not None:
|
594
593
|
if len(shape) < 2:
|
595
594
|
raise ValueError(f"Length of shape needs to be at least 2, got {shape}")
|
596
595
|
|
596
|
+
# Axial distance from grid center in multiples of sampling rate
|
597
597
|
sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
|
598
598
|
grid = fftfreqn(
|
599
599
|
shape=shape,
|
@@ -619,6 +619,7 @@ def create_ctf(
|
|
619
619
|
defocus_sum = np.add(defocus_x, defocus_y)
|
620
620
|
defocus_difference = np.subtract(defocus_x, defocus_y)
|
621
621
|
|
622
|
+
# Reusing grid, but in principle pure frequencies would suffice
|
622
623
|
angular_grid = np.arctan2(grid[1], grid[0])
|
623
624
|
defocus_difference = np.multiply(
|
624
625
|
defocus_difference,
|
@@ -627,7 +628,7 @@ def create_ctf(
|
|
627
628
|
defocus_x = np.add(defocus_sum, defocus_difference)
|
628
629
|
defocus_x *= 0.5
|
629
630
|
|
630
|
-
frequency_grid = fftfreqn(shape, sampling_rate=
|
631
|
+
frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
|
631
632
|
if angle is not None and opening_axis is not None and full_shape is not None:
|
632
633
|
frequency_grid = frequency_grid_at_angle(
|
633
634
|
shape=full_shape,
|
tme/filters/wedge.py
CHANGED
@@ -15,7 +15,7 @@ import numpy as np
|
|
15
15
|
from ..types import NDArray
|
16
16
|
from ..backends import backend as be
|
17
17
|
from .compose import ComposableFilter
|
18
|
-
from ..matching_utils import
|
18
|
+
from ..matching_utils import center_slice
|
19
19
|
from ..rotations import euler_to_rotationmatrix
|
20
20
|
from ..parser import XMLParser, StarParser, MDOCParser
|
21
21
|
from ._utils import (
|
@@ -207,11 +207,10 @@ class Wedge(ComposableFilter):
|
|
207
207
|
)
|
208
208
|
sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
|
209
209
|
sigma = -2 * np.pi**2 * sigma**2
|
210
|
-
np.square(frequency_grid, out=frequency_grid)
|
211
|
-
np.multiply(sigma, frequency_grid, out=frequency_grid)
|
212
|
-
np.exp(frequency_grid, out=frequency_grid)
|
213
|
-
np.multiply(frequency_grid, np.cos(np.radians(angle))
|
214
|
-
wedges[index] = frequency_grid
|
210
|
+
frequency_grid = np.square(frequency_grid, out=frequency_grid)
|
211
|
+
frequency_grid = np.multiply(sigma, frequency_grid, out=frequency_grid)
|
212
|
+
frequency_grid = np.exp(frequency_grid, out=frequency_grid)
|
213
|
+
wedges[index] = np.multiply(frequency_grid, np.cos(np.radians(angle)))
|
215
214
|
|
216
215
|
return wedges
|
217
216
|
|
@@ -490,7 +489,11 @@ class WedgeReconstructed:
|
|
490
489
|
)
|
491
490
|
wedge_volume += plane_rotated * weights[index]
|
492
491
|
|
493
|
-
|
492
|
+
subset = center_slice(
|
493
|
+
wedge_volume.shape, (shape[opening_axis], shape[tilt_axis])
|
494
|
+
)
|
495
|
+
wedge_volume = wedge_volume[subset]
|
496
|
+
|
494
497
|
np.fmin(wedge_volume, np.max(weights), wedge_volume)
|
495
498
|
|
496
499
|
if opening_axis > tilt_axis:
|
tme/mask.py
ADDED
@@ -0,0 +1,341 @@
|
|
1
|
+
"""
|
2
|
+
Utility functions for generating template matching masks.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from typing import Tuple, Optional
|
11
|
+
|
12
|
+
from .types import NDArray
|
13
|
+
from scipy.ndimage import gaussian_filter
|
14
|
+
from .matching_utils import rigid_transform
|
15
|
+
|
16
|
+
__all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]
|
17
|
+
|
18
|
+
|
19
|
+
def elliptical_mask(
|
20
|
+
shape: Tuple[int],
|
21
|
+
radius: Tuple[float],
|
22
|
+
center: Optional[Tuple[float]] = None,
|
23
|
+
orientation: Optional[NDArray] = None,
|
24
|
+
sigma_decay: float = 0.0,
|
25
|
+
cutoff_sigma: float = 3,
|
26
|
+
**kwargs,
|
27
|
+
) -> NDArray:
|
28
|
+
"""
|
29
|
+
Creates an ellipsoidal mask.
|
30
|
+
|
31
|
+
Parameters
|
32
|
+
----------
|
33
|
+
shape : tuple of ints
|
34
|
+
Shape of the mask to be created.
|
35
|
+
radius : tuple of floats
|
36
|
+
Radius of the mask.
|
37
|
+
center : tuple of floats, optional
|
38
|
+
Center of the mask, default to shape // 2.
|
39
|
+
orientation : NDArray, optional.
|
40
|
+
Orientation of the mask as rotation matrix with shape (d,d).
|
41
|
+
|
42
|
+
Returns
|
43
|
+
-------
|
44
|
+
NDArray
|
45
|
+
The created ellipsoidal mask.
|
46
|
+
|
47
|
+
Raises
|
48
|
+
------
|
49
|
+
ValueError
|
50
|
+
If the length of center and radius is not one or the same as shape.
|
51
|
+
|
52
|
+
Examples
|
53
|
+
--------
|
54
|
+
>>> from tme.matching_utils import elliptical_mask
|
55
|
+
>>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
|
56
|
+
"""
|
57
|
+
shape, radius = np.asarray(shape), np.asarray(radius)
|
58
|
+
|
59
|
+
shape = shape.astype(int)
|
60
|
+
if center is None:
|
61
|
+
center = np.divide(shape, 2).astype(int)
|
62
|
+
|
63
|
+
center = np.asarray(center, dtype=np.float32)
|
64
|
+
radius = np.repeat(radius, shape.size // radius.size)
|
65
|
+
center = np.repeat(center, shape.size // center.size)
|
66
|
+
if radius.size != shape.size:
|
67
|
+
raise ValueError("Length of radius has to be either one or match shape.")
|
68
|
+
if center.size != shape.size:
|
69
|
+
raise ValueError("Length of center has to be either one or match shape.")
|
70
|
+
|
71
|
+
n = shape.size
|
72
|
+
center = center.reshape((-1,) + (1,) * n)
|
73
|
+
radius = radius.reshape((-1,) + (1,) * n)
|
74
|
+
|
75
|
+
indices = np.indices(shape, dtype=np.float32) - center
|
76
|
+
if orientation is not None:
|
77
|
+
return_shape = indices.shape
|
78
|
+
indices = indices.reshape(n, -1)
|
79
|
+
rigid_transform(
|
80
|
+
coordinates=indices,
|
81
|
+
rotation_matrix=np.asarray(orientation),
|
82
|
+
out=indices,
|
83
|
+
translation=np.zeros(n),
|
84
|
+
use_geometric_center=False,
|
85
|
+
)
|
86
|
+
indices = indices.reshape(*return_shape)
|
87
|
+
|
88
|
+
dist = np.linalg.norm(indices / radius, axis=0)
|
89
|
+
if sigma_decay > 0:
|
90
|
+
sigma_decay = 2 * (sigma_decay / np.mean(radius)) ** 2
|
91
|
+
mask = np.maximum(0, dist - 1)
|
92
|
+
mask = np.exp(-(mask**2) / sigma_decay)
|
93
|
+
mask *= mask > np.exp(-(cutoff_sigma**2) / 2)
|
94
|
+
else:
|
95
|
+
mask = (dist <= 1).astype(int)
|
96
|
+
return mask
|
97
|
+
|
98
|
+
|
99
|
+
def box_mask(
|
100
|
+
shape: Tuple[int],
|
101
|
+
center: Tuple[int],
|
102
|
+
size: Tuple[int],
|
103
|
+
sigma_decay: float = 0.0,
|
104
|
+
cutoff_sigma: float = 3.0,
|
105
|
+
**kwargs,
|
106
|
+
) -> np.ndarray:
|
107
|
+
"""
|
108
|
+
Creates a box mask centered around the provided center point.
|
109
|
+
|
110
|
+
Parameters
|
111
|
+
----------
|
112
|
+
shape : tuple of ints
|
113
|
+
Shape of the output array.
|
114
|
+
center : tuple of ints
|
115
|
+
Center point coordinates of the box.
|
116
|
+
size : tuple of ints
|
117
|
+
Side length of the box along each axis.
|
118
|
+
|
119
|
+
Returns
|
120
|
+
-------
|
121
|
+
NDArray
|
122
|
+
The created box mask.
|
123
|
+
|
124
|
+
Raises
|
125
|
+
------
|
126
|
+
ValueError
|
127
|
+
If ``shape`` and ``center`` do not have the same length.
|
128
|
+
If ``center`` and ``height`` do not have the same length.
|
129
|
+
"""
|
130
|
+
if len(shape) != len(center) or len(center) != len(size):
|
131
|
+
raise ValueError("The length of shape, center, and height must be consistent.")
|
132
|
+
|
133
|
+
shape = tuple(int(x) for x in shape)
|
134
|
+
center, size = np.array(center, dtype=int), np.array(size, dtype=int)
|
135
|
+
|
136
|
+
half_heights = size // 2
|
137
|
+
starts = np.maximum(center - half_heights, 0)
|
138
|
+
stops = np.minimum(center + half_heights + np.mod(size, 2) + 1, shape)
|
139
|
+
slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))
|
140
|
+
|
141
|
+
out = np.zeros(shape)
|
142
|
+
out[slice_indices] = 1
|
143
|
+
|
144
|
+
if sigma_decay > 0:
|
145
|
+
mask_filter = gaussian_filter(
|
146
|
+
out.astype(np.float32), sigma=sigma_decay, truncate=cutoff_sigma
|
147
|
+
)
|
148
|
+
out = np.add(out, (1 - out) * mask_filter)
|
149
|
+
out *= out > np.exp(-(cutoff_sigma**2) / 2)
|
150
|
+
return out
|
151
|
+
|
152
|
+
|
153
|
+
def tube_mask(
|
154
|
+
shape: Tuple[int],
|
155
|
+
symmetry_axis: int,
|
156
|
+
center: Tuple[int],
|
157
|
+
inner_radius: float,
|
158
|
+
outer_radius: float,
|
159
|
+
height: int,
|
160
|
+
sigma_decay: float = 0.0,
|
161
|
+
cutoff_sigma: float = 3.0,
|
162
|
+
**kwargs,
|
163
|
+
) -> NDArray:
|
164
|
+
"""
|
165
|
+
Creates a tube mask.
|
166
|
+
|
167
|
+
Parameters
|
168
|
+
----------
|
169
|
+
shape : tuple
|
170
|
+
Shape of the mask to be created.
|
171
|
+
symmetry_axis : int
|
172
|
+
The axis of symmetry for the tube.
|
173
|
+
base_center : tuple
|
174
|
+
Center of the tube.
|
175
|
+
inner_radius : float
|
176
|
+
Inner radius of the tube.
|
177
|
+
outer_radius : float
|
178
|
+
Outer radius of the tube.
|
179
|
+
height : int
|
180
|
+
Height of the tube.
|
181
|
+
|
182
|
+
Returns
|
183
|
+
-------
|
184
|
+
NDArray
|
185
|
+
The created tube mask.
|
186
|
+
|
187
|
+
Raises
|
188
|
+
------
|
189
|
+
ValueError
|
190
|
+
If ``inner_radius`` is larger than ``outer_radius``.
|
191
|
+
If ``height`` is larger than the symmetry axis.
|
192
|
+
If ``base_center`` and ``shape`` do not have the same length.
|
193
|
+
"""
|
194
|
+
if inner_radius > outer_radius:
|
195
|
+
raise ValueError("inner_radius should be smaller than outer_radius.")
|
196
|
+
|
197
|
+
if height > shape[symmetry_axis]:
|
198
|
+
raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")
|
199
|
+
|
200
|
+
if symmetry_axis > len(shape):
|
201
|
+
raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
|
202
|
+
|
203
|
+
if len(center) != len(shape):
|
204
|
+
raise ValueError("shape and base_center need to have the same length.")
|
205
|
+
|
206
|
+
shape = tuple(int(x) for x in shape)
|
207
|
+
circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
|
208
|
+
circle_center = tuple(b for ix, b in enumerate(center) if ix != symmetry_axis)
|
209
|
+
|
210
|
+
inner_circle = np.zeros(circle_shape)
|
211
|
+
outer_circle = np.zeros_like(inner_circle)
|
212
|
+
if inner_radius > 0:
|
213
|
+
inner_circle = elliptical_mask(
|
214
|
+
shape=circle_shape,
|
215
|
+
radius=inner_radius,
|
216
|
+
center=circle_center,
|
217
|
+
sigma_decay=sigma_decay,
|
218
|
+
cutoff_sigma=cutoff_sigma,
|
219
|
+
)
|
220
|
+
if outer_radius > 0:
|
221
|
+
outer_circle = elliptical_mask(
|
222
|
+
shape=circle_shape,
|
223
|
+
radius=outer_radius,
|
224
|
+
center=circle_center,
|
225
|
+
sigma_decay=sigma_decay,
|
226
|
+
cutoff_sigma=cutoff_sigma,
|
227
|
+
)
|
228
|
+
circle = outer_circle - inner_circle
|
229
|
+
circle = np.expand_dims(circle, axis=symmetry_axis)
|
230
|
+
|
231
|
+
center = center[symmetry_axis]
|
232
|
+
start_idx = int(center - height // 2)
|
233
|
+
stop_idx = int(center + height // 2 + height % 2)
|
234
|
+
|
235
|
+
start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])
|
236
|
+
|
237
|
+
slice_indices = tuple(
|
238
|
+
slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
|
239
|
+
for i in range(len(shape))
|
240
|
+
)
|
241
|
+
tube = np.zeros(shape)
|
242
|
+
tube[slice_indices] = circle
|
243
|
+
|
244
|
+
return tube
|
245
|
+
|
246
|
+
|
247
|
+
def membrane_mask(
|
248
|
+
shape: Tuple[int],
|
249
|
+
radius: float,
|
250
|
+
thickness: float,
|
251
|
+
separation: float,
|
252
|
+
symmetry_axis: int = 2,
|
253
|
+
center: Optional[Tuple[float]] = None,
|
254
|
+
sigma_decay: float = 0.5,
|
255
|
+
cutoff_sigma: float = 3,
|
256
|
+
**kwargs,
|
257
|
+
) -> NDArray:
|
258
|
+
"""
|
259
|
+
Creates a membrane mask consisting of two parallel disks with Gaussian intensity profile.
|
260
|
+
Uses efficient broadcasting approach: flat disk mask × height profile.
|
261
|
+
|
262
|
+
Parameters
|
263
|
+
----------
|
264
|
+
shape : tuple of ints
|
265
|
+
Shape of the mask to be created.
|
266
|
+
radius : float
|
267
|
+
Radius of the membrane disks.
|
268
|
+
thickness : float
|
269
|
+
Thickness of each disk in the membrane.
|
270
|
+
separation : float
|
271
|
+
Distance between the centers of the two disks.
|
272
|
+
symmetry_axis : int, optional
|
273
|
+
The axis perpendicular to the membrane disks, defaults to 2.
|
274
|
+
center : tuple of floats, optional
|
275
|
+
Center of the membrane (midpoint between the two disks), defaults to shape // 2.
|
276
|
+
sigma_decay : float, optional
|
277
|
+
Controls edge sharpness relative to radius, defaults to 0.5.
|
278
|
+
cutoff_sigma : float, optional
|
279
|
+
Cutoff for height profile in standard deviations, defaults to 3.
|
280
|
+
|
281
|
+
Returns
|
282
|
+
-------
|
283
|
+
NDArray
|
284
|
+
The created membrane mask with Gaussian intensity profile.
|
285
|
+
|
286
|
+
Raises
|
287
|
+
------
|
288
|
+
ValueError
|
289
|
+
If ``thickness`` is negative.
|
290
|
+
If ``separation`` is negative.
|
291
|
+
If ``center`` and ``shape`` do not have the same length.
|
292
|
+
If ``symmetry_axis`` is out of bounds.
|
293
|
+
|
294
|
+
Examples
|
295
|
+
--------
|
296
|
+
>>> from tme.matching_utils import membrane_mask
|
297
|
+
>>> mask = membrane_mask(shape=(50,50,50), radius=10, thickness=2, separation=15)
|
298
|
+
"""
|
299
|
+
shape = np.asarray(shape, dtype=int)
|
300
|
+
|
301
|
+
if center is None:
|
302
|
+
center = np.divide(shape, 2).astype(float)
|
303
|
+
|
304
|
+
center = np.asarray(center, dtype=np.float32)
|
305
|
+
center = np.repeat(center, shape.size // center.size)
|
306
|
+
|
307
|
+
if thickness < 0:
|
308
|
+
raise ValueError("thickness must be non-negative.")
|
309
|
+
if separation < 0:
|
310
|
+
raise ValueError("separation must be non-negative.")
|
311
|
+
if symmetry_axis >= len(shape):
|
312
|
+
raise ValueError(f"symmetry_axis must be less than {len(shape)}.")
|
313
|
+
if center.size != shape.size:
|
314
|
+
raise ValueError("Length of center has to be either one or match shape.")
|
315
|
+
|
316
|
+
disk_mask = elliptical_mask(
|
317
|
+
shape=[x for i, x in enumerate(shape) if i != symmetry_axis],
|
318
|
+
radius=radius,
|
319
|
+
sigma_decay=sigma_decay,
|
320
|
+
cutoff_sigma=cutoff_sigma,
|
321
|
+
)
|
322
|
+
|
323
|
+
axial_coord = np.arange(shape[symmetry_axis]) - center[symmetry_axis]
|
324
|
+
height_profile = np.zeros((shape[symmetry_axis],), dtype=np.float32)
|
325
|
+
for leaflet_pos in [-separation / 2, separation / 2]:
|
326
|
+
leaflet_profile = np.exp(
|
327
|
+
-((axial_coord - leaflet_pos) ** 2) / (2 * (thickness / 3) ** 2)
|
328
|
+
)
|
329
|
+
cutoff_threshold = np.exp(-(cutoff_sigma**2) / 2)
|
330
|
+
leaflet_profile *= leaflet_profile > cutoff_threshold
|
331
|
+
|
332
|
+
height_profile = np.maximum(height_profile, leaflet_profile)
|
333
|
+
|
334
|
+
disk_mask = disk_mask.reshape(
|
335
|
+
[x if i != symmetry_axis else 1 for i, x in enumerate(shape)]
|
336
|
+
)
|
337
|
+
height_profile = height_profile.reshape(
|
338
|
+
[1 if i != symmetry_axis else x for i, x in enumerate(shape)]
|
339
|
+
)
|
340
|
+
|
341
|
+
return disk_mask * height_profile
|
tme/matching_data.py
CHANGED
@@ -128,11 +128,8 @@ class MatchingData:
|
|
128
128
|
slice_start = np.array([x.start for x in arr_slice], dtype=int)
|
129
129
|
slice_stop = np.array([x.stop for x in arr_slice], dtype=int)
|
130
130
|
|
131
|
-
|
132
|
-
|
133
|
-
# is defined from the perspective of the origin
|
134
|
-
right_pad = np.divide(padding, 2).astype(int)
|
135
|
-
left_pad = np.add(right_pad, np.mod(padding, 2))
|
131
|
+
left_pad = np.divide(padding, 2).astype(int)
|
132
|
+
right_pad = np.add(left_pad, np.mod(padding, 2))
|
136
133
|
|
137
134
|
data_voxels_left = np.minimum(slice_start, left_pad)
|
138
135
|
data_voxels_right = np.minimum(
|
@@ -253,7 +250,7 @@ class MatchingData:
|
|
253
250
|
target_offset[mask] = [x.start for x in target_slice]
|
254
251
|
mask = np.subtract(1, self._target_batch).astype(bool)
|
255
252
|
template_offset = np.zeros(len(self._output_template_shape), dtype=int)
|
256
|
-
template_offset[mask] = [x.start for x
|
253
|
+
template_offset[mask] = [x.start for x in template_slice]
|
257
254
|
|
258
255
|
translation_offset = tuple(x for x in target_offset)
|
259
256
|
|
@@ -299,7 +296,7 @@ class MatchingData:
|
|
299
296
|
|
300
297
|
def set_matching_dimension(self, target_dim: int = None, template_dim: int = None):
|
301
298
|
"""
|
302
|
-
Sets matching dimensions for target and template.
|
299
|
+
Sets matching batch dimensions for target and template.
|
303
300
|
|
304
301
|
Parameters
|
305
302
|
----------
|
@@ -490,8 +487,8 @@ class MatchingData:
|
|
490
487
|
def _fourier_padding(
|
491
488
|
target_shape: Tuple[int],
|
492
489
|
template_shape: Tuple[int],
|
493
|
-
pad_target: bool = False,
|
494
490
|
batch_mask: Tuple[int] = None,
|
491
|
+
**kwargs,
|
495
492
|
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
496
493
|
if batch_mask is None:
|
497
494
|
batch_mask = np.zeros_like(template_shape)
|
@@ -503,7 +500,7 @@ class MatchingData:
|
|
503
500
|
|
504
501
|
# Avoid padding batch dimensions
|
505
502
|
pad_shape = np.maximum(target_shape, template_shape)
|
506
|
-
pad_shape = np.maximum(
|
503
|
+
pad_shape = np.maximum(pad_shape, np.multiply(1 - batch_mask, pad_shape))
|
507
504
|
ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
|
508
505
|
conv_shape, fast_shape, fast_ft_shape = ret
|
509
506
|
|
@@ -525,21 +522,13 @@ class MatchingData:
|
|
525
522
|
shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
|
526
523
|
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
527
524
|
|
528
|
-
if pad_target:
|
529
|
-
fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
|
530
|
-
|
531
525
|
fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
|
532
526
|
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
533
527
|
|
534
|
-
def fourier_padding(self,
|
528
|
+
def fourier_padding(self, **kwargs) -> Tuple:
|
535
529
|
"""
|
536
530
|
Computes efficient shape four Fourier transforms and potential associated shifts.
|
537
531
|
|
538
|
-
Parameters
|
539
|
-
----------
|
540
|
-
pad_target : bool, optional
|
541
|
-
Whether the target has been padded to the full convolution shape.
|
542
|
-
|
543
532
|
Returns
|
544
533
|
-------
|
545
534
|
Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
|
@@ -553,7 +542,6 @@ class MatchingData:
|
|
553
542
|
target_shape=be.to_numpy_array(self._output_target_shape),
|
554
543
|
template_shape=be.to_numpy_array(self._output_template_shape),
|
555
544
|
batch_mask=be.to_numpy_array(self._batch_mask),
|
556
|
-
pad_target=pad_target,
|
557
545
|
)
|
558
546
|
|
559
547
|
def computation_schedule(
|