pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +50 -103
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +50 -103
- scripts/pytme_runner.py +46 -69
- scripts/refine_matches.py +5 -7
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +124 -71
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +110 -105
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +102 -58
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +28 -8
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tme/filters/ctf.py
CHANGED
@@ -19,9 +19,7 @@ from ..parser import StarParser, XMLParser, MDOCParser
|
|
19
19
|
from ._utils import (
|
20
20
|
frequency_grid_at_angle,
|
21
21
|
compute_tilt_shape,
|
22
|
-
crop_real_fourier,
|
23
22
|
fftfreqn,
|
24
|
-
shift_fourier,
|
25
23
|
pad_to_length,
|
26
24
|
)
|
27
25
|
|
@@ -31,21 +29,17 @@ __all__ = ["CTF", "CTFReconstructed", "create_ctf"]
|
|
31
29
|
@dataclass
|
32
30
|
class CTF(ComposableFilter):
|
33
31
|
"""
|
34
|
-
Generate
|
32
|
+
Generate per-tilt contrast transfer function filter.
|
35
33
|
"""
|
36
34
|
|
37
|
-
#: The shape of the to-be created mask.
|
38
|
-
shape: Tuple[int] = None
|
39
35
|
#: The defocus in x direction (in units of sampling rate).
|
40
36
|
defocus_x: Tuple[float] = None
|
41
|
-
#: The tilt angles.
|
37
|
+
#: The tilt angles in degrees.
|
42
38
|
angles: Tuple[float] = None
|
43
39
|
#: The microscope projection axis, defaults to 2 (z).
|
44
40
|
opening_axis: int = 2
|
45
41
|
#: The axis along which the tilt is applied, defaults to 0 (x).
|
46
42
|
tilt_axis: int = 0
|
47
|
-
#: Whether to correct defocus gradient, defaults False.
|
48
|
-
correct_defocus_gradient: bool = False
|
49
43
|
#: The sampling rate, defaults to 1 Ångstrom / voxel.
|
50
44
|
sampling_rate: Tuple[float] = 1
|
51
45
|
#: The acceleration voltage in Volts, defaults to 300e3.
|
@@ -54,16 +48,14 @@ class CTF(ComposableFilter):
|
|
54
48
|
spherical_aberration: Tuple[float] = 2.7e7
|
55
49
|
#: The amplitude contrast, defaults to 0.07.
|
56
50
|
amplitude_contrast: Tuple[float] = 0.07
|
57
|
-
#: The phase shift in
|
51
|
+
#: The phase shift in radians, defaults to 0.
|
58
52
|
phase_shift: Tuple[float] = 0
|
59
|
-
#: The defocus angle in
|
53
|
+
#: The defocus angle in radians, defaults to 0.
|
60
54
|
defocus_angle: Tuple[float] = 0
|
61
55
|
#: The defocus value in y direction, defaults to None (in units of sampling rate).
|
62
56
|
defocus_y: Tuple[float] = None
|
63
57
|
#: Whether the returned CTF should be phase-flipped, defaults to True.
|
64
58
|
flip_phase: bool = True
|
65
|
-
#: Whether to return a ctf mask for rfft (for :py:class:`CTFReconstructed`).
|
66
|
-
return_real_fourier: bool = False
|
67
59
|
|
68
60
|
@classmethod
|
69
61
|
def from_file(cls, filename: str, **kwargs) -> "CTF":
|
@@ -99,7 +91,6 @@ class CTF(ComposableFilter):
|
|
99
91
|
|
100
92
|
# Pixel size needs to be overwritten by pixel size the ctf is generated for
|
101
93
|
init_kwargs = {
|
102
|
-
"shape": None,
|
103
94
|
"angles": data.get("angles", None),
|
104
95
|
"defocus_x": data["defocus_1"],
|
105
96
|
"sampling_rate": data["pixel_size"],
|
@@ -114,31 +105,17 @@ class CTF(ComposableFilter):
|
|
114
105
|
if k in init_kwargs and init_kwargs.get(k) is None:
|
115
106
|
init_kwargs[k] = v
|
116
107
|
init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}
|
117
|
-
return cls(**init_kwargs)
|
118
|
-
|
119
|
-
def __post_init__(self):
|
120
|
-
self.defocus_angle = np.radians(self.defocus_angle)
|
121
|
-
self.phase_shift = np.radians(self.phase_shift)
|
122
108
|
|
123
|
-
|
124
|
-
""
|
125
|
-
|
126
|
-
""
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
ret = self.weight(**func_args)
|
131
|
-
ret = be.astype(be.to_backend_array(ret), be._float_dtype)
|
132
|
-
return {
|
133
|
-
"data": ret,
|
134
|
-
"shape": func_args["shape"],
|
135
|
-
"return_real_fourier": func_args.get("return_real_fourier"),
|
136
|
-
"is_multiplicative_filter": True,
|
137
|
-
}
|
109
|
+
# Moved format conversion from __post_init__
|
110
|
+
if "phase_shift" in init_kwargs:
|
111
|
+
init_kwargs["phase_shift"] = np.radians(init_kwargs["phase_shift"])
|
112
|
+
if "defocus_angle" in init_kwargs:
|
113
|
+
init_kwargs["defocus_angle"] = np.radians(init_kwargs["defocus_angle"])
|
114
|
+
return cls(**init_kwargs)
|
138
115
|
|
139
|
-
def
|
116
|
+
def _evaluate(
|
140
117
|
self,
|
141
|
-
shape: Tuple[int],
|
118
|
+
shape: Tuple[int, ...],
|
142
119
|
defocus_x: Tuple[float],
|
143
120
|
angles: Tuple[float],
|
144
121
|
opening_axis: int = 2,
|
@@ -147,15 +124,13 @@ class CTF(ComposableFilter):
|
|
147
124
|
phase_shift: Tuple[float] = 0,
|
148
125
|
defocus_angle: Tuple[float] = 0,
|
149
126
|
defocus_y: Tuple[float] = None,
|
150
|
-
correct_defocus_gradient: bool = False,
|
151
127
|
sampling_rate: Tuple[float] = 1,
|
152
128
|
acceleration_voltage: float = 300e3,
|
153
|
-
spherical_aberration: float = 2.
|
129
|
+
spherical_aberration: float = 2.7e7,
|
154
130
|
flip_phase: bool = True,
|
155
|
-
return_real_fourier: bool = False,
|
156
131
|
cutoff_frequency: float = 0.5,
|
157
132
|
**kwargs: Dict,
|
158
|
-
) ->
|
133
|
+
) -> Dict:
|
159
134
|
"""
|
160
135
|
Compute the CTF weight tilt stack.
|
161
136
|
|
@@ -164,38 +139,31 @@ class CTF(ComposableFilter):
|
|
164
139
|
shape : tuple of int
|
165
140
|
The shape of the CTF.
|
166
141
|
defocus_x : tuple of float
|
167
|
-
|
142
|
+
Defocus along the first principal axis in spatial units of sampling rate.
|
168
143
|
angles : tuple of float
|
169
|
-
The tilt angles.
|
144
|
+
The tilt angles in degrees.
|
170
145
|
opening_axis : int, optional
|
171
146
|
The axis around which the wedge is opened, defaults to 2.
|
172
147
|
tilt_axis : int, optional
|
173
148
|
The axis along which the tilt is applied, defaults to 0.
|
174
149
|
amplitude_contrast : tuple of float, optional
|
175
|
-
|
150
|
+
Amplitude contrast of microscope, defaults to 0.07.
|
176
151
|
phase_shift : tuple of float, optional
|
177
|
-
|
152
|
+
CTF phase shift in radians, defaults to 0.
|
178
153
|
defocus_angle : tuple of float, optional
|
179
|
-
|
154
|
+
Astigmatism angle in radians, defaults to 0.
|
180
155
|
defocus_y : tuple of float, optional
|
181
|
-
|
182
|
-
correct_defocus_gradient : bool, optional
|
183
|
-
Whether to correct defocus gradient, defaults to False.
|
156
|
+
Defocus along the second principal axis in spatial units of sampling rate.
|
184
157
|
sampling_rate : tuple of float, optional
|
185
158
|
The sampling rate, defaults to 1.
|
186
159
|
acceleration_voltage : float, optional
|
187
160
|
The acceleration voltage in electron microscopy, defaults to 300e3.
|
188
161
|
spherical_aberration : float, optional
|
189
|
-
|
162
|
+
Spherical aberration of microscope in units of sampling rate.
|
190
163
|
flip_phase : bool, optional
|
191
164
|
Whether the returned CTF should be phase-flipped, defaults to True.
|
192
165
|
**kwargs : Dict
|
193
166
|
Additional keyword arguments.
|
194
|
-
|
195
|
-
Returns
|
196
|
-
-------
|
197
|
-
NDArray
|
198
|
-
A stack containing the CTF weight.
|
199
167
|
"""
|
200
168
|
angles = np.atleast_1d(angles)
|
201
169
|
defoci_x = pad_to_length(defocus_x, angles.size)
|
@@ -219,7 +187,6 @@ class CTF(ComposableFilter):
|
|
219
187
|
corrected_tilt_axis -= 1
|
220
188
|
|
221
189
|
for index, angle in enumerate(angles):
|
222
|
-
correction = correct_defocus_gradient and angle is not None
|
223
190
|
chi = create_ctf(
|
224
191
|
angle=angle,
|
225
192
|
shape=ctf_shape,
|
@@ -227,7 +194,6 @@ class CTF(ComposableFilter):
|
|
227
194
|
defocus_y=defoci_y[index],
|
228
195
|
sampling_rate=sampling_rate,
|
229
196
|
acceleration_voltage=acceleration_voltage[index],
|
230
|
-
correct_defocus_gradient=correction,
|
231
197
|
spherical_aberration=spherical_aberration[index],
|
232
198
|
cutoff_frequency=cutoff_frequency,
|
233
199
|
phase_shift=phase_shift[index],
|
@@ -244,16 +210,16 @@ class CTF(ComposableFilter):
|
|
244
210
|
stack = np.negative(stack, out=stack)
|
245
211
|
if flip_phase:
|
246
212
|
stack = np.abs(stack, out=stack)
|
247
|
-
return be.to_backend_array(
|
213
|
+
return {"data": be.to_backend_array(stack), "shape": shape}
|
248
214
|
|
249
215
|
|
216
|
+
@dataclass
|
250
217
|
class CTFReconstructed(CTF):
|
251
218
|
"""
|
252
|
-
|
253
|
-
per-tilt parameters like in :py:class:`CTF`.
|
219
|
+
Generate CTF filter for reconstructions.
|
254
220
|
"""
|
255
221
|
|
256
|
-
def
|
222
|
+
def _evaluate(
|
257
223
|
self,
|
258
224
|
shape: Tuple[int],
|
259
225
|
defocus_x: Tuple[float],
|
@@ -265,10 +231,9 @@ class CTFReconstructed(CTF):
|
|
265
231
|
acceleration_voltage: float = 300e3,
|
266
232
|
spherical_aberration: float = 2.7e3,
|
267
233
|
flip_phase: bool = True,
|
268
|
-
return_real_fourier: bool = False,
|
269
234
|
cutoff_frequency: float = 0.5,
|
270
235
|
**kwargs: Dict,
|
271
|
-
) ->
|
236
|
+
) -> Dict:
|
272
237
|
"""
|
273
238
|
Compute the CTF weight tilt stack.
|
274
239
|
|
@@ -277,17 +242,17 @@ class CTFReconstructed(CTF):
|
|
277
242
|
shape : tuple of int
|
278
243
|
The shape of the CTF.
|
279
244
|
defocus_x : tuple of float
|
280
|
-
|
245
|
+
Defocus along the first principal axis in spatial units of sampling rate.
|
281
246
|
opening_axis : int, optional
|
282
247
|
The axis around which the wedge is opened, defaults to 2.
|
283
248
|
amplitude_contrast : float, optional
|
284
249
|
The amplitude contrast, defaults to 0.07.
|
285
250
|
phase_shift : tuple of float, optional
|
286
|
-
|
251
|
+
CTF phase shift in radians, defaults to 0.
|
287
252
|
defocus_angle : tuple of float, optional
|
288
253
|
The defocus angle in radians, defaults to 0.
|
289
254
|
defocus_y : tuple of float, optional
|
290
|
-
|
255
|
+
Defocus along the second principal axis in spatial units of sampling rate.
|
291
256
|
sampling_rate : tuple of float, optional
|
292
257
|
The sampling rate, defaults to 1.
|
293
258
|
acceleration_voltage : float, optional
|
@@ -310,7 +275,6 @@ class CTFReconstructed(CTF):
|
|
310
275
|
defocus_y=defocus_y,
|
311
276
|
sampling_rate=np.max(sampling_rate),
|
312
277
|
acceleration_voltage=self.acceleration_voltage,
|
313
|
-
correct_defocus_gradient=False,
|
314
278
|
spherical_aberration=spherical_aberration,
|
315
279
|
cutoff_frequency=cutoff_frequency,
|
316
280
|
phase_shift=phase_shift,
|
@@ -318,14 +282,10 @@ class CTFReconstructed(CTF):
|
|
318
282
|
amplitude_contrast=amplitude_contrast,
|
319
283
|
)
|
320
284
|
# Avoid contrast inversion
|
321
|
-
np.negative(stack, out=stack)
|
285
|
+
stack = np.negative(stack, out=stack)
|
322
286
|
if flip_phase:
|
323
|
-
np.abs(stack, out=stack)
|
324
|
-
|
325
|
-
stack = shift_fourier(data=stack, shape_is_real_fourier=False)
|
326
|
-
if return_real_fourier:
|
327
|
-
stack = crop_real_fourier(stack)
|
328
|
-
return be.to_backend_array(np.squeeze(stack))
|
287
|
+
stack = np.abs(stack, out=stack)
|
288
|
+
return {"data": be.to_backend_array(stack), "shape": shape}
|
329
289
|
|
330
290
|
|
331
291
|
def _from_xml(filename: str) -> Dict:
|
@@ -501,7 +461,7 @@ def _from_mdoc(filename: str) -> Dict:
|
|
501
461
|
return output
|
502
462
|
|
503
463
|
|
504
|
-
def _compute_electron_wavelength(acceleration_voltage: int =
|
464
|
+
def _compute_electron_wavelength(acceleration_voltage: int = 300e3):
|
505
465
|
"""Computes the wavelength of an electron in angstrom."""
|
506
466
|
|
507
467
|
# Physical constants expressed in SI units
|
@@ -524,14 +484,13 @@ def _compute_electron_wavelength(acceleration_voltage: int = None):
|
|
524
484
|
def create_ctf(
|
525
485
|
shape: Tuple[int],
|
526
486
|
defocus_x: float,
|
527
|
-
acceleration_voltage: float =
|
487
|
+
acceleration_voltage: float = 300e3,
|
528
488
|
defocus_angle: float = 0,
|
529
489
|
phase_shift: float = 0,
|
530
490
|
defocus_y: float = None,
|
531
491
|
sampling_rate: float = 1,
|
532
492
|
spherical_aberration: float = 2.7e7,
|
533
493
|
amplitude_contrast: float = 0.07,
|
534
|
-
correct_defocus_gradient: bool = False,
|
535
494
|
cutoff_frequency: float = 0.5,
|
536
495
|
angle: float = None,
|
537
496
|
tilt_axis: int = 0,
|
@@ -546,15 +505,16 @@ def create_ctf(
|
|
546
505
|
shape : Tuple[int]
|
547
506
|
Shape of the returned CTF mask.
|
548
507
|
defocus_x : float
|
549
|
-
Defocus
|
508
|
+
Defocus along the first principal axis in spatial units of sampling rate,
|
509
|
+
e.g. 30000 Angstrom.
|
550
510
|
acceleration_voltage : float, optional
|
551
|
-
Acceleration voltage in keV, defaults to
|
511
|
+
Acceleration voltage in Volts, defaults to 300e3.
|
552
512
|
defocus_angle : float, optional
|
553
|
-
Astigmatism in radians, defaults to 0.
|
513
|
+
Astigmatism angle in radians, defaults to 0.
|
554
514
|
phase_shift : float, optional
|
555
|
-
|
515
|
+
CTF phase shift in radians, defaults to 0.
|
556
516
|
defocus_y : float, optional
|
557
|
-
Defocus
|
517
|
+
Defocus along the second principal axis in spatial units of sampling rate.
|
558
518
|
tilt_axis : int, optional
|
559
519
|
Axes the specimen was tilted over, defaults to 0 (x-axis).
|
560
520
|
sampling_rate : float or tuple of floats
|
@@ -564,7 +524,7 @@ def create_ctf(
|
|
564
524
|
spherical_aberration : float, optional
|
565
525
|
Spherical aberration of microscope in units of sampling rate.
|
566
526
|
angle : float, optional
|
567
|
-
Assume the created CTF is a projection
|
527
|
+
Assume the created CTF is a projection observed at angle degrees.
|
568
528
|
opening_axis : int, optional
|
569
529
|
Projection axis, only relevant if angle is given.
|
570
530
|
full_shape : tuple of ints
|
@@ -589,31 +549,18 @@ def create_ctf(
|
|
589
549
|
|
590
550
|
defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
|
591
551
|
defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
|
592
|
-
if
|
552
|
+
if defocus_y is not None:
|
593
553
|
if len(shape) < 2:
|
594
554
|
raise ValueError(f"Length of shape needs to be at least 2, got {shape}")
|
595
555
|
|
596
|
-
# Axial distance from grid center in
|
597
|
-
sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
|
556
|
+
# Axial distance from grid center in voxels
|
598
557
|
grid = fftfreqn(
|
599
558
|
shape=shape,
|
600
|
-
sampling_rate=
|
559
|
+
sampling_rate=None,
|
601
560
|
return_sparse_grid=True,
|
561
|
+
fftshift=False,
|
602
562
|
)
|
603
563
|
|
604
|
-
# This should be done after defocus_x computation
|
605
|
-
if correct_defocus_gradient:
|
606
|
-
if angle is None:
|
607
|
-
raise ValueError("Cannot correct for defocus gradient without angle.")
|
608
|
-
|
609
|
-
angle_rad = np.radians(angle)
|
610
|
-
defocus_gradient = np.multiply(grid[tilt_axis], np.sin(angle_rad))
|
611
|
-
|
612
|
-
if tilt_axis == 0:
|
613
|
-
defocus_x = np.add(defocus_x, defocus_gradient)
|
614
|
-
elif tilt_axis == 1 and defocus_y is not None:
|
615
|
-
defocus_y = np.add(defocus_y, defocus_gradient)
|
616
|
-
|
617
564
|
# 0.5 * (dx + dy) + cos(2 * (azimuth - astigmatism) * (dx - dy))
|
618
565
|
if defocus_y is not None:
|
619
566
|
defocus_sum = np.add(defocus_x, defocus_y)
|
@@ -628,7 +575,9 @@ def create_ctf(
|
|
628
575
|
defocus_x = np.add(defocus_sum, defocus_difference)
|
629
576
|
defocus_x *= 0.5
|
630
577
|
|
631
|
-
frequency_grid = fftfreqn(
|
578
|
+
frequency_grid = fftfreqn(
|
579
|
+
shape, sampling_rate=1, compute_euclidean_norm=True, fftshift=False
|
580
|
+
)
|
632
581
|
if angle is not None and opening_axis is not None and full_shape is not None:
|
633
582
|
frequency_grid = frequency_grid_at_angle(
|
634
583
|
shape=full_shape,
|
@@ -636,8 +585,9 @@ def create_ctf(
|
|
636
585
|
opening_axis=opening_axis,
|
637
586
|
angle=angle,
|
638
587
|
sampling_rate=1,
|
588
|
+
fftshift=False,
|
639
589
|
)
|
640
|
-
frequency_mask = frequency_grid
|
590
|
+
frequency_mask = frequency_grid <= cutoff_frequency
|
641
591
|
|
642
592
|
# k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
|
643
593
|
frequency_grid = np.square(frequency_grid, out=frequency_grid)
|
@@ -652,5 +602,4 @@ def create_ctf(
|
|
652
602
|
)
|
653
603
|
)
|
654
604
|
chi = np.sin(-chi, out=chi)
|
655
|
-
|
656
|
-
return chi
|
605
|
+
return np.multiply(chi, frequency_mask, out=chi)
|
tme/filters/reconstruction.py
CHANGED
@@ -1,22 +1,22 @@
|
|
1
1
|
"""
|
2
|
-
|
2
|
+
Implements classes ReconstructFromTilt and ShiftFourier.
|
3
3
|
|
4
4
|
Copyright (c) 2024 European Molecular Biology Laboratory
|
5
5
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
from typing import Tuple
|
9
|
+
from typing import Tuple, Dict
|
10
10
|
from dataclasses import dataclass
|
11
11
|
|
12
12
|
import numpy as np
|
13
13
|
|
14
|
+
from ..types import BackendArray
|
14
15
|
from ..backends import backend as be
|
15
|
-
from ..types import NDArray, BackendArray
|
16
16
|
|
17
17
|
from .compose import ComposableFilter
|
18
18
|
from ..rotations import euler_to_rotationmatrix
|
19
|
-
from ._utils import
|
19
|
+
from ._utils import shift_fourier, create_reconstruction_filter
|
20
20
|
|
21
21
|
__all__ = ["ReconstructFromTilt", "ShiftFourier"]
|
22
22
|
|
@@ -24,45 +24,56 @@ __all__ = ["ReconstructFromTilt", "ShiftFourier"]
|
|
24
24
|
@dataclass
|
25
25
|
class ReconstructFromTilt(ComposableFilter):
|
26
26
|
"""
|
27
|
-
|
28
|
-
backprojection
|
27
|
+
Place Fourier transforms of d-dimensional inputs into a d+1-dimensional array
|
28
|
+
akin to weighted backprojection using direct Fourier inversion.
|
29
|
+
|
30
|
+
This class is used to reconstruct the output of ComposableFilter instances for
|
31
|
+
individual tilts to be applied to query templates.
|
32
|
+
|
33
|
+
See Also
|
34
|
+
--------
|
35
|
+
:py:class:`tme.filters.CTF`
|
36
|
+
:py:class:`tme.filters.Wedge`
|
37
|
+
:py:class:`tme.filters.BandPass`
|
38
|
+
|
29
39
|
"""
|
30
40
|
|
31
|
-
#:
|
32
|
-
shape: Tuple[int] = None
|
33
|
-
#: Angle of each individual tilt.
|
41
|
+
#: Angle of each individual tilt in degrees.
|
34
42
|
angles: Tuple[float] = None
|
35
43
|
#: Projection axis, defaults to 2 (z).
|
36
44
|
opening_axis: int = 2
|
37
45
|
#: Tilt axis, defaults to 0 (x).
|
38
46
|
tilt_axis: int = 0
|
39
|
-
#: Whether to return a share compliant with rfftn.
|
40
|
-
return_real_fourier: bool = True
|
41
47
|
#: Interpolation order used for rotation
|
42
48
|
interpolation_order: int = 1
|
43
49
|
#: Filter window applied during reconstruction.
|
44
50
|
reconstruction_filter: str = None
|
45
51
|
|
46
|
-
|
52
|
+
@staticmethod
|
53
|
+
def _evaluate(
|
54
|
+
data: BackendArray,
|
55
|
+
shape: Tuple[int, ...],
|
56
|
+
angles: Tuple[float],
|
57
|
+
opening_axis: int = 2,
|
58
|
+
tilt_axis: int = 0,
|
59
|
+
interpolation_order: int = 1,
|
60
|
+
reconstruction_filter: str = None,
|
61
|
+
**kwargs,
|
62
|
+
) -> Dict:
|
47
63
|
"""
|
48
|
-
Reconstruct a
|
64
|
+
Reconstruct a 3-dimensional array from n 2-dimensional inputs using WBP.
|
49
65
|
|
50
66
|
Parameters
|
51
67
|
----------
|
52
|
-
shape : tuple of int
|
53
|
-
The shape of the reconstruction volume.
|
54
68
|
data : BackendArray
|
55
69
|
D-dimensional image stack with shape (n, ...). The data is assumed to be
|
56
|
-
|
57
|
-
DC component
|
70
|
+
the Fourier transform of the stack you are trying to reconstruct with
|
71
|
+
DC component at the origin. Notably, the data needs to be the output of
|
72
|
+
np.fft.fftn, not the reduced np.fft.rfftn.
|
73
|
+
shape : tuple of int
|
74
|
+
The shape of the reconstruction volume.
|
58
75
|
angles : tuple of float
|
59
|
-
Angle
|
60
|
-
return_real_fourier : bool, optional
|
61
|
-
Return a shape compliant
|
62
|
-
return_real_fourier : tuple of int
|
63
|
-
Return a shape compliant with rfft, i.e., omit the negative frequencies
|
64
|
-
terms resulting in a return shape (*shape[:-1], shape[-1]//2+1). Defaults
|
65
|
-
to False.
|
76
|
+
Angle to place individual slices at in degrees.
|
66
77
|
reconstruction_filter : str, optional
|
67
78
|
Filter window applied during reconstruction.
|
68
79
|
See :py:meth:`create_reconstruction_filter` for available options.
|
@@ -70,80 +81,21 @@ class ReconstructFromTilt(ComposableFilter):
|
|
70
81
|
Axis the plane is tilted over, defaults to 0 (x).
|
71
82
|
opening_axis : int
|
72
83
|
The projection axis, defaults to 2 (z).
|
73
|
-
|
74
|
-
Returns
|
75
|
-
-------
|
76
|
-
dict
|
77
|
-
data: BackendArray
|
78
|
-
The filter mask.
|
79
|
-
shape: tuple of ints
|
80
|
-
The requested filter shape
|
81
|
-
return_real_fourier: bool
|
82
|
-
Whether data is compliant with rfftn.
|
83
|
-
is_multiplicative_filter: bool
|
84
|
-
Whether the filter is multiplicative in Fourier space.
|
85
|
-
"""
|
86
|
-
|
87
|
-
func_args = vars(self).copy()
|
88
|
-
func_args.update(kwargs)
|
89
|
-
|
90
|
-
ret = self.reconstruct(**func_args)
|
91
|
-
|
92
|
-
ret = shift_fourier(data=ret, shape_is_real_fourier=False)
|
93
|
-
if return_real_fourier:
|
94
|
-
ret = crop_real_fourier(ret)
|
95
|
-
|
96
|
-
return {
|
97
|
-
"data": ret,
|
98
|
-
"shape": func_args["shape"],
|
99
|
-
"return_real_fourier": return_real_fourier,
|
100
|
-
"is_multiplicative_filter": False,
|
101
|
-
}
|
102
|
-
|
103
|
-
@staticmethod
|
104
|
-
def reconstruct(
|
105
|
-
data: NDArray,
|
106
|
-
shape: Tuple[int],
|
107
|
-
angles: Tuple[float],
|
108
|
-
opening_axis: int,
|
109
|
-
tilt_axis: int,
|
110
|
-
interpolation_order: int = 1,
|
111
|
-
reconstruction_filter: str = None,
|
112
|
-
**kwargs,
|
113
|
-
):
|
114
84
|
"""
|
115
|
-
Reconstruct a volume from a tilt series.
|
116
|
-
|
117
|
-
Parameters
|
118
|
-
----------
|
119
|
-
data : NDArray
|
120
|
-
The Fourier transform of tilt series data.
|
121
|
-
shape : tuple of int
|
122
|
-
Shape of the reconstruction.
|
123
|
-
angles : tuple of float
|
124
|
-
Angle of each individual tilt.
|
125
|
-
opening_axis : int
|
126
|
-
The axis around which the volume is opened.
|
127
|
-
tilt_axis : int
|
128
|
-
Axis the plane is tilted over.
|
129
|
-
interpolation_order : int, optional
|
130
|
-
Interpolation order used for rotation, defaults to 1.
|
131
|
-
reconstruction_filter : bool, optional
|
132
|
-
Filter window applied during reconstruction.
|
133
|
-
See :py:meth:`create_reconstruction_filter` for available options.
|
134
85
|
|
135
|
-
Returns
|
136
|
-
-------
|
137
|
-
NDArray
|
138
|
-
The reconstructed volume.
|
139
|
-
"""
|
140
86
|
if data.shape == shape:
|
141
87
|
return data
|
142
88
|
|
143
|
-
|
89
|
+
# Composable filters use frequency grids centered at the origin
|
90
|
+
# Here we require them to be centered at subset.shape // 2
|
91
|
+
for i in range(data.shape[0]):
|
92
|
+
data_shifted = shift_fourier(
|
93
|
+
data[i], shape_is_real_fourier=False, ifftshift=False
|
94
|
+
)
|
95
|
+
data = be.at(data, i, data_shifted)
|
96
|
+
|
144
97
|
volume_temp = be.zeros(shape, dtype=data.dtype)
|
145
|
-
|
146
|
-
volume = be.zeros(shape, dtype=data.dtype)
|
98
|
+
rec = be.zeros(shape, dtype=data.dtype)
|
147
99
|
|
148
100
|
slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape)
|
149
101
|
subset = tuple(
|
@@ -162,58 +114,51 @@ class ReconstructFromTilt(ComposableFilter):
|
|
162
114
|
filter_type=reconstruction_filter,
|
163
115
|
filter_shape=(shape[tilt_axis],),
|
164
116
|
tilt_angles=angles,
|
117
|
+
fftshift=True,
|
165
118
|
)
|
166
119
|
rec_shape = tuple(1 if i != tilt_axis else x for i, x in enumerate(shape))
|
167
120
|
rec_filter = be.to_backend_array(rec_filter)
|
168
121
|
rec_filter = be.reshape(rec_filter, rec_shape)
|
169
122
|
|
170
123
|
angles = be.to_backend_array(angles)
|
124
|
+
axis_index = min(
|
125
|
+
tuple(i for i in range(len(shape)) if i not in (tilt_axis, opening_axis))
|
126
|
+
)
|
171
127
|
for index in range(len(angles)):
|
172
128
|
angles_loop = be.fill(angles_loop, 0)
|
173
129
|
volume_temp = be.fill(volume_temp, 0)
|
174
|
-
volume_temp_rotated = be.fill(volume_temp_rotated, 0)
|
175
130
|
|
176
131
|
# Jax compatibility
|
177
132
|
volume_temp = be.at(volume_temp, subset, wedges[index] * rec_filter)
|
178
|
-
angles_loop = be.at(angles_loop,
|
133
|
+
angles_loop = be.at(angles_loop, axis_index, angles[index])
|
179
134
|
|
180
|
-
|
181
|
-
rotation_matrix = euler_to_rotationmatrix(
|
182
|
-
|
135
|
+
# We want a push rotation but rigid transform assumes pull
|
136
|
+
rotation_matrix = euler_to_rotationmatrix(
|
137
|
+
be.to_numpy_array(angles_loop), seq="xyz"
|
138
|
+
).T
|
183
139
|
|
184
|
-
|
140
|
+
volume_temp, _ = be.rigid_transform(
|
185
141
|
arr=volume_temp,
|
186
|
-
rotation_matrix=rotation_matrix,
|
187
|
-
out=volume_temp_rotated,
|
142
|
+
rotation_matrix=be.to_backend_array(rotation_matrix),
|
188
143
|
use_geometric_center=True,
|
189
144
|
order=interpolation_order,
|
190
145
|
)
|
191
|
-
|
146
|
+
rec = be.add(rec, volume_temp, out=rec)
|
192
147
|
|
193
|
-
|
148
|
+
# Shift DC component back to origin
|
149
|
+
rec = shift_fourier(rec, shape_is_real_fourier=False, ifftshift=True)
|
150
|
+
return {"data": rec, "shape": shape, "is_multiplicative_filter": False}
|
194
151
|
|
195
152
|
|
196
153
|
class ShiftFourier(ComposableFilter):
|
197
|
-
def
|
198
|
-
self,
|
199
|
-
data: BackendArray,
|
200
|
-
shape_is_real_fourier: bool = False,
|
201
|
-
return_real_fourier: bool = True,
|
202
|
-
**kwargs,
|
203
|
-
):
|
154
|
+
def _evaluate(self, shape: Tuple[int, ...], data: BackendArray, **kwargs) -> Dict:
|
204
155
|
ret = []
|
205
156
|
for index in range(data.shape[0]):
|
206
|
-
mask =
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
mask = crop_real_fourier(mask)
|
157
|
+
mask = shift_fourier(
|
158
|
+
data=data[index],
|
159
|
+
shape_is_real_fourier=kwargs.get("return_real_fourier", False),
|
160
|
+
)
|
211
161
|
ret.append(mask[None])
|
212
|
-
|
213
|
-
|
214
|
-
return {
|
215
|
-
"data": ret,
|
216
|
-
"shape": kwargs.get("shape"),
|
217
|
-
"return_real_fourier": return_real_fourier,
|
218
|
-
"is_multiplicative_filter": False,
|
219
|
-
}
|
162
|
+
|
163
|
+
ret = be.concatenate(ret, axis=0)
|
164
|
+
return {"data": ret, "shape": shape, "is_multiplicative_filter": False}
|