pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
  6. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +49 -103
  15. scripts/pytme_runner.py +46 -69
  16. tests/preprocessing/test_compose.py +31 -30
  17. tests/preprocessing/test_frequency_filters.py +17 -32
  18. tests/preprocessing/test_preprocessor.py +0 -19
  19. tests/preprocessing/test_utils.py +13 -1
  20. tests/test_analyzer.py +2 -10
  21. tests/test_backends.py +47 -18
  22. tests/test_density.py +72 -13
  23. tests/test_extensions.py +1 -0
  24. tests/test_matching_cli.py +23 -9
  25. tests/test_matching_exhaustive.py +5 -5
  26. tests/test_matching_utils.py +3 -3
  27. tests/test_orientations.py +12 -0
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +91 -68
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +103 -98
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +44 -57
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +17 -3
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post2.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/filters/ctf.py CHANGED
@@ -19,9 +19,7 @@ from ..parser import StarParser, XMLParser, MDOCParser
19
19
  from ._utils import (
20
20
  frequency_grid_at_angle,
21
21
  compute_tilt_shape,
22
- crop_real_fourier,
23
22
  fftfreqn,
24
- shift_fourier,
25
23
  pad_to_length,
26
24
  )
27
25
 
@@ -31,21 +29,17 @@ __all__ = ["CTF", "CTFReconstructed", "create_ctf"]
31
29
  @dataclass
32
30
  class CTF(ComposableFilter):
33
31
  """
34
- Generate a per-tilt contrast transfer function mask.
32
+ Generate per-tilt contrast transfer function filter.
35
33
  """
36
34
 
37
- #: The shape of the to-be created mask.
38
- shape: Tuple[int] = None
39
35
  #: The defocus in x direction (in units of sampling rate).
40
36
  defocus_x: Tuple[float] = None
41
- #: The tilt angles.
37
+ #: The tilt angles in degrees.
42
38
  angles: Tuple[float] = None
43
39
  #: The microscope projection axis, defaults to 2 (z).
44
40
  opening_axis: int = 2
45
41
  #: The axis along which the tilt is applied, defaults to 0 (x).
46
42
  tilt_axis: int = 0
47
- #: Whether to correct defocus gradient, defaults False.
48
- correct_defocus_gradient: bool = False
49
43
  #: The sampling rate, defaults to 1 Ångstrom / voxel.
50
44
  sampling_rate: Tuple[float] = 1
51
45
  #: The acceleration voltage in Volts, defaults to 300e3.
@@ -54,16 +48,14 @@ class CTF(ComposableFilter):
54
48
  spherical_aberration: Tuple[float] = 2.7e7
55
49
  #: The amplitude contrast, defaults to 0.07.
56
50
  amplitude_contrast: Tuple[float] = 0.07
57
- #: The phase shift in degrees, defaults to 0.
51
+ #: The phase shift in radians, defaults to 0.
58
52
  phase_shift: Tuple[float] = 0
59
- #: The defocus angle in degrees, defaults to 0.
53
+ #: The defocus angle in radians, defaults to 0.
60
54
  defocus_angle: Tuple[float] = 0
61
55
  #: The defocus value in y direction, defaults to None (in units of sampling rate).
62
56
  defocus_y: Tuple[float] = None
63
57
  #: Whether the returned CTF should be phase-flipped, defaults to True.
64
58
  flip_phase: bool = True
65
- #: Whether to return a ctf mask for rfft (for :py:class:`CTFReconstructed`).
66
- return_real_fourier: bool = False
67
59
 
68
60
  @classmethod
69
61
  def from_file(cls, filename: str, **kwargs) -> "CTF":
@@ -99,7 +91,6 @@ class CTF(ComposableFilter):
99
91
 
100
92
  # Pixel size needs to be overwritten by pixel size the ctf is generated for
101
93
  init_kwargs = {
102
- "shape": None,
103
94
  "angles": data.get("angles", None),
104
95
  "defocus_x": data["defocus_1"],
105
96
  "sampling_rate": data["pixel_size"],
@@ -114,31 +105,17 @@ class CTF(ComposableFilter):
114
105
  if k in init_kwargs and init_kwargs.get(k) is None:
115
106
  init_kwargs[k] = v
116
107
  init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}
117
- return cls(**init_kwargs)
118
-
119
- def __post_init__(self):
120
- self.defocus_angle = np.radians(self.defocus_angle)
121
- self.phase_shift = np.radians(self.phase_shift)
122
108
 
123
- def __call__(self, **kwargs) -> NDArray:
124
- """
125
- Returns a CTF stack of chosen parameters with DC component in the center.
126
- """
127
- func_args = vars(self).copy()
128
- func_args.update(kwargs)
129
-
130
- ret = self.weight(**func_args)
131
- ret = be.astype(be.to_backend_array(ret), be._float_dtype)
132
- return {
133
- "data": ret,
134
- "shape": func_args["shape"],
135
- "return_real_fourier": func_args.get("return_real_fourier"),
136
- "is_multiplicative_filter": True,
137
- }
109
+ # Convert degrees to radians here (previously done in __post_init__)
110
+ if "phase_shift" in init_kwargs:
111
+ init_kwargs["phase_shift"] = np.radians(init_kwargs["phase_shift"])
112
+ if "defocus_angle" in init_kwargs:
113
+ init_kwargs["defocus_angle"] = np.radians(init_kwargs["defocus_angle"])
114
+ return cls(**init_kwargs)
138
115
 
139
- def weight(
116
+ def _evaluate(
140
117
  self,
141
- shape: Tuple[int],
118
+ shape: Tuple[int, ...],
142
119
  defocus_x: Tuple[float],
143
120
  angles: Tuple[float],
144
121
  opening_axis: int = 2,
@@ -147,15 +124,13 @@ class CTF(ComposableFilter):
147
124
  phase_shift: Tuple[float] = 0,
148
125
  defocus_angle: Tuple[float] = 0,
149
126
  defocus_y: Tuple[float] = None,
150
- correct_defocus_gradient: bool = False,
151
127
  sampling_rate: Tuple[float] = 1,
152
128
  acceleration_voltage: float = 300e3,
153
- spherical_aberration: float = 2.7e3,
129
+ spherical_aberration: float = 2.7e7,
154
130
  flip_phase: bool = True,
155
- return_real_fourier: bool = False,
156
131
  cutoff_frequency: float = 0.5,
157
132
  **kwargs: Dict,
158
- ) -> NDArray:
133
+ ) -> Dict:
159
134
  """
160
135
  Compute the CTF weight tilt stack.
161
136
 
@@ -164,38 +139,31 @@ class CTF(ComposableFilter):
164
139
  shape : tuple of int
165
140
  The shape of the CTF.
166
141
  defocus_x : tuple of float
167
- The defocus in x direction (in units of sampling rate).
142
+ Defocus along the first principal axis in spatial units of sampling rate.
168
143
  angles : tuple of float
169
- The tilt angles.
144
+ The tilt angles in degrees.
170
145
  opening_axis : int, optional
171
146
  The axis around which the wedge is opened, defaults to 2.
172
147
  tilt_axis : int, optional
173
148
  The axis along which the tilt is applied, defaults to 0.
174
149
  amplitude_contrast : tuple of float, optional
175
- The amplitude contrast, defaults to 0.07.
150
+ Amplitude contrast of microscope, defaults to 0.07.
176
151
  phase_shift : tuple of float, optional
177
- The phase shift in radians, defaults to 0.
152
+ CTF phase shift in radians, defaults to 0.
178
153
  defocus_angle : tuple of float, optional
179
- The defocus angle in radians, defaults to 0.
154
+ Astigmatism angle in radians, defaults to 0.
180
155
  defocus_y : tuple of float, optional
181
- The defocus in x direction (in units of sampling rate).
182
- correct_defocus_gradient : bool, optional
183
- Whether to correct defocus gradient, defaults to False.
156
+ Defocus along the second principal axis in spatial units of sampling rate.
184
157
  sampling_rate : tuple of float, optional
185
158
  The sampling rate, defaults to 1.
186
159
  acceleration_voltage : float, optional
187
160
  The acceleration voltage in electron microscopy, defaults to 300e3.
188
161
  spherical_aberration : float, optional
189
- The spherical aberration coefficient, defaults to 2.7e3.
162
+ Spherical aberration of microscope in units of sampling rate.
190
163
  flip_phase : bool, optional
191
164
  Whether the returned CTF should be phase-flipped, defaults to True.
192
165
  **kwargs : Dict
193
166
  Additional keyword arguments.
194
-
195
- Returns
196
- -------
197
- NDArray
198
- A stack containing the CTF weight.
199
167
  """
200
168
  angles = np.atleast_1d(angles)
201
169
  defoci_x = pad_to_length(defocus_x, angles.size)
@@ -219,7 +187,6 @@ class CTF(ComposableFilter):
219
187
  corrected_tilt_axis -= 1
220
188
 
221
189
  for index, angle in enumerate(angles):
222
- correction = correct_defocus_gradient and angle is not None
223
190
  chi = create_ctf(
224
191
  angle=angle,
225
192
  shape=ctf_shape,
@@ -227,7 +194,6 @@ class CTF(ComposableFilter):
227
194
  defocus_y=defoci_y[index],
228
195
  sampling_rate=sampling_rate,
229
196
  acceleration_voltage=acceleration_voltage[index],
230
- correct_defocus_gradient=correction,
231
197
  spherical_aberration=spherical_aberration[index],
232
198
  cutoff_frequency=cutoff_frequency,
233
199
  phase_shift=phase_shift[index],
@@ -244,16 +210,16 @@ class CTF(ComposableFilter):
244
210
  stack = np.negative(stack, out=stack)
245
211
  if flip_phase:
246
212
  stack = np.abs(stack, out=stack)
247
- return be.to_backend_array(np.squeeze(stack))
213
+ return {"data": be.to_backend_array(stack), "shape": shape}
248
214
 
249
215
 
216
+ @dataclass
250
217
  class CTFReconstructed(CTF):
251
218
  """
252
- Create a simple contrast transfer function mask without the ability to specify
253
- per-tilt parameters like in :py:class:`CTF`.
219
+ Generate CTF filter for reconstructions.
254
220
  """
255
221
 
256
- def weight(
222
+ def _evaluate(
257
223
  self,
258
224
  shape: Tuple[int],
259
225
  defocus_x: Tuple[float],
@@ -265,10 +231,9 @@ class CTFReconstructed(CTF):
265
231
  acceleration_voltage: float = 300e3,
266
232
  spherical_aberration: float = 2.7e3,
267
233
  flip_phase: bool = True,
268
- return_real_fourier: bool = False,
269
234
  cutoff_frequency: float = 0.5,
270
235
  **kwargs: Dict,
271
- ) -> NDArray:
236
+ ) -> Dict:
272
237
  """
273
238
  Compute the CTF weight tilt stack.
274
239
 
@@ -277,17 +242,17 @@ class CTFReconstructed(CTF):
277
242
  shape : tuple of int
278
243
  The shape of the CTF.
279
244
  defocus_x : tuple of float
280
- The defocus in x direction in units of sampling rate.
245
+ Defocus along the first principal axis in spatial units of sampling rate.
281
246
  opening_axis : int, optional
282
247
  The axis around which the wedge is opened, defaults to 2.
283
248
  amplitude_contrast : float, optional
284
249
  The amplitude contrast, defaults to 0.07.
285
250
  phase_shift : tuple of float, optional
286
- The phase shift in radians, defaults to 0.
251
+ CTF phase shift in radians, defaults to 0.
287
252
  defocus_angle : tuple of float, optional
288
253
  The defocus angle in radians, defaults to 0.
289
254
  defocus_y : tuple of float, optional
290
- The defocus in y direction in units of sampling rate.
255
+ Defocus along the second principal axis in spatial units of sampling rate.
291
256
  sampling_rate : tuple of float, optional
292
257
  The sampling rate, defaults to 1.
293
258
  acceleration_voltage : float, optional
@@ -310,7 +275,6 @@ class CTFReconstructed(CTF):
310
275
  defocus_y=defocus_y,
311
276
  sampling_rate=np.max(sampling_rate),
312
277
  acceleration_voltage=self.acceleration_voltage,
313
- correct_defocus_gradient=False,
314
278
  spherical_aberration=spherical_aberration,
315
279
  cutoff_frequency=cutoff_frequency,
316
280
  phase_shift=phase_shift,
@@ -318,14 +282,10 @@ class CTFReconstructed(CTF):
318
282
  amplitude_contrast=amplitude_contrast,
319
283
  )
320
284
  # Avoid contrast inversion
321
- np.negative(stack, out=stack)
285
+ stack = np.negative(stack, out=stack)
322
286
  if flip_phase:
323
- np.abs(stack, out=stack)
324
-
325
- stack = shift_fourier(data=stack, shape_is_real_fourier=False)
326
- if return_real_fourier:
327
- stack = crop_real_fourier(stack)
328
- return be.to_backend_array(np.squeeze(stack))
287
+ stack = np.abs(stack, out=stack)
288
+ return {"data": be.to_backend_array(stack), "shape": shape}
329
289
 
330
290
 
331
291
  def _from_xml(filename: str) -> Dict:
@@ -501,7 +461,7 @@ def _from_mdoc(filename: str) -> Dict:
501
461
  return output
502
462
 
503
463
 
504
- def _compute_electron_wavelength(acceleration_voltage: int = None):
464
+ def _compute_electron_wavelength(acceleration_voltage: int = 300e3):
505
465
  """Computes the wavelength of an electron in angstrom."""
506
466
 
507
467
  # Physical constants expressed in SI units
@@ -524,14 +484,13 @@ def _compute_electron_wavelength(acceleration_voltage: int = None):
524
484
  def create_ctf(
525
485
  shape: Tuple[int],
526
486
  defocus_x: float,
527
- acceleration_voltage: float = 300,
487
+ acceleration_voltage: float = 300e3,
528
488
  defocus_angle: float = 0,
529
489
  phase_shift: float = 0,
530
490
  defocus_y: float = None,
531
491
  sampling_rate: float = 1,
532
492
  spherical_aberration: float = 2.7e7,
533
493
  amplitude_contrast: float = 0.07,
534
- correct_defocus_gradient: bool = False,
535
494
  cutoff_frequency: float = 0.5,
536
495
  angle: float = None,
537
496
  tilt_axis: int = 0,
@@ -546,15 +505,16 @@ def create_ctf(
546
505
  shape : Tuple[int]
547
506
  Shape of the returned CTF mask.
548
507
  defocus_x : float
549
- Defocus in x in units of sampling rate, e.g. 30000 Angstrom.
508
+ Defocus along the first principal axis in spatial units of sampling rate,
509
+ e.g. 30000 Angstrom.
550
510
  acceleration_voltage : float, optional
551
- Acceleration voltage in keV, defaults to 300.
511
+ Acceleration voltage in Volts, defaults to 300e3.
552
512
  defocus_angle : float, optional
553
- Astigmatism in radians, defaults to 0.
513
+ Astigmatism angle in radians, defaults to 0.
554
514
  phase_shift : float, optional
555
- Phase shift from phase plate in radians, defaults to 0.
515
+ CTF phase shift in radians, defaults to 0.
556
516
  defocus_y : float, optional
557
- Defocus in y in units of sampling rate.
517
+ Defocus along the second principal axis in spatial units of sampling rate.
558
518
  tilt_axis : int, optional
559
519
  Axes the specimen was tilted over, defaults to 0 (x-axis).
560
520
  sampling_rate : float or tuple of floats
@@ -564,7 +524,7 @@ def create_ctf(
564
524
  spherical_aberration : float, optional
565
525
  Spherical aberration of microscope in units of sampling rate.
566
526
  angle : float, optional
567
- Assume the created CTF is a projection over opening_axis observed at angle.
527
+ Assume the created CTF is a projection observed at `angle` degrees.
568
528
  opening_axis : int, optional
569
529
  Projection axis, only relevant if angle is given.
570
530
  full_shape : tuple of ints
@@ -589,31 +549,18 @@ def create_ctf(
589
549
 
590
550
  defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
591
551
  defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
592
- if correct_defocus_gradient or defocus_y is not None:
552
+ if defocus_y is not None:
593
553
  if len(shape) < 2:
594
554
  raise ValueError(f"Length of shape needs to be at least 2, got {shape}")
595
555
 
596
- # Axial distance from grid center in multiples of sampling rate
597
- sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
556
+ # Axial distance from grid center in voxels
598
557
  grid = fftfreqn(
599
558
  shape=shape,
600
- sampling_rate=sampling,
559
+ sampling_rate=None,
601
560
  return_sparse_grid=True,
561
+ fftshift=False,
602
562
  )
603
563
 
604
- # This should be done after defocus_x computation
605
- if correct_defocus_gradient:
606
- if angle is None:
607
- raise ValueError("Cannot correct for defocus gradient without angle.")
608
-
609
- angle_rad = np.radians(angle)
610
- defocus_gradient = np.multiply(grid[tilt_axis], np.sin(angle_rad))
611
-
612
- if tilt_axis == 0:
613
- defocus_x = np.add(defocus_x, defocus_gradient)
614
- elif tilt_axis == 1 and defocus_y is not None:
615
- defocus_y = np.add(defocus_y, defocus_gradient)
616
-
617
564
  # 0.5 * (dx + dy) + cos(2 * (azimuth - astigmatism) * (dx - dy))
618
565
  if defocus_y is not None:
619
566
  defocus_sum = np.add(defocus_x, defocus_y)
@@ -628,7 +575,9 @@ def create_ctf(
628
575
  defocus_x = np.add(defocus_sum, defocus_difference)
629
576
  defocus_x *= 0.5
630
577
 
631
- frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
578
+ frequency_grid = fftfreqn(
579
+ shape, sampling_rate=1, compute_euclidean_norm=True, fftshift=False
580
+ )
632
581
  if angle is not None and opening_axis is not None and full_shape is not None:
633
582
  frequency_grid = frequency_grid_at_angle(
634
583
  shape=full_shape,
@@ -636,8 +585,9 @@ def create_ctf(
636
585
  opening_axis=opening_axis,
637
586
  angle=angle,
638
587
  sampling_rate=1,
588
+ fftshift=False,
639
589
  )
640
- frequency_mask = frequency_grid < cutoff_frequency
590
+ frequency_mask = frequency_grid <= cutoff_frequency
641
591
 
642
592
  # k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
643
593
  frequency_grid = np.square(frequency_grid, out=frequency_grid)
@@ -652,5 +602,4 @@ def create_ctf(
652
602
  )
653
603
  )
654
604
  chi = np.sin(-chi, out=chi)
655
- chi = np.multiply(chi, frequency_mask, out=chi)
656
- return chi
605
+ return np.multiply(chi, frequency_mask, out=chi)
@@ -1,22 +1,22 @@
1
1
  """
2
- Defines filters on tomographic tilt series.
2
+ Implements classes ReconstructFromTilt and ShiftFourier.
3
3
 
4
4
  Copyright (c) 2024 European Molecular Biology Laboratory
5
5
 
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- from typing import Tuple
9
+ from typing import Tuple, Dict
10
10
  from dataclasses import dataclass
11
11
 
12
12
  import numpy as np
13
13
 
14
+ from ..types import BackendArray
14
15
  from ..backends import backend as be
15
- from ..types import NDArray, BackendArray
16
16
 
17
17
  from .compose import ComposableFilter
18
18
  from ..rotations import euler_to_rotationmatrix
19
- from ._utils import crop_real_fourier, shift_fourier, create_reconstruction_filter
19
+ from ._utils import shift_fourier, create_reconstruction_filter
20
20
 
21
21
  __all__ = ["ReconstructFromTilt", "ShiftFourier"]
22
22
 
@@ -24,45 +24,56 @@ __all__ = ["ReconstructFromTilt", "ShiftFourier"]
24
24
  @dataclass
25
25
  class ReconstructFromTilt(ComposableFilter):
26
26
  """
27
- Reconstruct a d+1 array from a d-dimensional input projection using weighted
28
- backprojection (WBP).
27
+ Place Fourier transforms of d-dimensional inputs into a d+1-dimensional array
28
+ akin to weighted backprojection using direct Fourier inversion.
29
+
30
+ This class is used to reconstruct the output of ComposableFilter instances for
31
+ individual tilts to be applied to query templates.
32
+
33
+ See Also
34
+ --------
35
+ :py:class:`tme.filters.CTF`
36
+ :py:class:`tme.filters.Wedge`
37
+ :py:class:`tme.filters.BandPass`
38
+
29
39
  """
30
40
 
31
- #: Shape of the reconstruction.
32
- shape: Tuple[int] = None
33
- #: Angle of each individual tilt.
41
+ #: Angle of each individual tilt in degrees.
34
42
  angles: Tuple[float] = None
35
43
  #: Projection axis, defaults to 2 (z).
36
44
  opening_axis: int = 2
37
45
  #: Tilt axis, defaults to 0 (x).
38
46
  tilt_axis: int = 0
39
- #: Whether to return a share compliant with rfftn.
40
- return_real_fourier: bool = True
41
47
  #: Interpolation order used for rotation
42
48
  interpolation_order: int = 1
43
49
  #: Filter window applied during reconstruction.
44
50
  reconstruction_filter: str = None
45
51
 
46
- def __call__(self, return_real_fourier: bool = False, **kwargs):
52
+ @staticmethod
53
+ def _evaluate(
54
+ data: BackendArray,
55
+ shape: Tuple[int, ...],
56
+ angles: Tuple[float],
57
+ opening_axis: int = 2,
58
+ tilt_axis: int = 0,
59
+ interpolation_order: int = 1,
60
+ reconstruction_filter: str = None,
61
+ **kwargs,
62
+ ) -> Dict:
47
63
  """
48
- Reconstruct a d+1 array from a d-dimensional input using WBP.
64
+ Reconstruct a 3-dimensional array from n 2-dimensional inputs using WBP.
49
65
 
50
66
  Parameters
51
67
  ----------
52
- shape : tuple of int
53
- The shape of the reconstruction volume.
54
68
  data : BackendArray
55
69
  D-dimensional image stack with shape (n, ...). The data is assumed to be
56
- a Fourier transform of the stack you are trying to reconstruct with
57
- DC component in the center.
70
+ the Fourier transform of the stack you are trying to reconstruct with
71
+ DC component at the origin. Notably, the data needs to be the output of
72
+ np.fft.fftn, not the reduced np.fft.rfftn.
73
+ shape : tuple of int
74
+ The shape of the reconstruction volume.
58
75
  angles : tuple of float
59
- Angle of each individual tilt.
60
- return_real_fourier : bool, optional
61
- Return a shape compliant
62
- return_real_fourier : tuple of int
63
- Return a shape compliant with rfft, i.e., omit the negative frequencies
64
- terms resulting in a return shape (*shape[:-1], shape[-1]//2+1). Defaults
65
- to False.
76
+ Angle to place individual slices at in degrees.
66
77
  reconstruction_filter : bool, optional
67
78
  Filter window applied during reconstruction.
68
79
  See :py:meth:`create_reconstruction_filter` for available options.
@@ -70,80 +81,21 @@ class ReconstructFromTilt(ComposableFilter):
70
81
  Axis the plane is tilted over, defaults to 0 (x).
71
82
  opening_axis : int
72
83
  The projection axis, defaults to 2 (z).
73
-
74
- Returns
75
- -------
76
- dict
77
- data: BackendArray
78
- The filter mask.
79
- shape: tuple of ints
80
- The requested filter shape
81
- return_real_fourier: bool
82
- Whether data is compliant with rfftn.
83
- is_multiplicative_filter: bool
84
- Whether the filter is multiplicative in Fourier space.
85
- """
86
-
87
- func_args = vars(self).copy()
88
- func_args.update(kwargs)
89
-
90
- ret = self.reconstruct(**func_args)
91
-
92
- ret = shift_fourier(data=ret, shape_is_real_fourier=False)
93
- if return_real_fourier:
94
- ret = crop_real_fourier(ret)
95
-
96
- return {
97
- "data": ret,
98
- "shape": func_args["shape"],
99
- "return_real_fourier": return_real_fourier,
100
- "is_multiplicative_filter": False,
101
- }
102
-
103
- @staticmethod
104
- def reconstruct(
105
- data: NDArray,
106
- shape: Tuple[int],
107
- angles: Tuple[float],
108
- opening_axis: int,
109
- tilt_axis: int,
110
- interpolation_order: int = 1,
111
- reconstruction_filter: str = None,
112
- **kwargs,
113
- ):
114
84
  """
115
- Reconstruct a volume from a tilt series.
116
-
117
- Parameters
118
- ----------
119
- data : NDArray
120
- The Fourier transform of tilt series data.
121
- shape : tuple of int
122
- Shape of the reconstruction.
123
- angles : tuple of float
124
- Angle of each individual tilt.
125
- opening_axis : int
126
- The axis around which the volume is opened.
127
- tilt_axis : int
128
- Axis the plane is tilted over.
129
- interpolation_order : int, optional
130
- Interpolation order used for rotation, defaults to 1.
131
- reconstruction_filter : bool, optional
132
- Filter window applied during reconstruction.
133
- See :py:meth:`create_reconstruction_filter` for available options.
134
85
 
135
- Returns
136
- -------
137
- NDArray
138
- The reconstructed volume.
139
- """
140
86
  if data.shape == shape:
141
87
  return data
142
88
 
143
- data = be.to_backend_array(data)
89
+ # Composable filters use frequency grids centered at the origin
90
+ # Here we require them to be centered at subset.shape // 2
91
+ for i in range(data.shape[0]):
92
+ data_shifted = shift_fourier(
93
+ data[i], shape_is_real_fourier=False, ifftshift=False
94
+ )
95
+ data = be.at(data, i, data_shifted)
96
+
144
97
  volume_temp = be.zeros(shape, dtype=data.dtype)
145
- volume_temp_rotated = be.zeros(shape, dtype=data.dtype)
146
- volume = be.zeros(shape, dtype=data.dtype)
98
+ rec = be.zeros(shape, dtype=data.dtype)
147
99
 
148
100
  slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape)
149
101
  subset = tuple(
@@ -162,58 +114,51 @@ class ReconstructFromTilt(ComposableFilter):
162
114
  filter_type=reconstruction_filter,
163
115
  filter_shape=(shape[tilt_axis],),
164
116
  tilt_angles=angles,
117
+ fftshift=True,
165
118
  )
166
119
  rec_shape = tuple(1 if i != tilt_axis else x for i, x in enumerate(shape))
167
120
  rec_filter = be.to_backend_array(rec_filter)
168
121
  rec_filter = be.reshape(rec_filter, rec_shape)
169
122
 
170
123
  angles = be.to_backend_array(angles)
124
+ axis_index = min(
125
+ tuple(i for i in range(len(shape)) if i not in (tilt_axis, opening_axis))
126
+ )
171
127
  for index in range(len(angles)):
172
128
  angles_loop = be.fill(angles_loop, 0)
173
129
  volume_temp = be.fill(volume_temp, 0)
174
- volume_temp_rotated = be.fill(volume_temp_rotated, 0)
175
130
 
176
131
  # Jax compatibility
177
132
  volume_temp = be.at(volume_temp, subset, wedges[index] * rec_filter)
178
- angles_loop = be.at(angles_loop, tilt_axis, angles[index])
133
+ angles_loop = be.at(angles_loop, axis_index, angles[index])
179
134
 
180
- angles_loop = be.roll(angles_loop, (opening_axis - 1,), axis=0)
181
- rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles_loop))
182
- rotation_matrix = be.to_backend_array(rotation_matrix)
135
+ # We want a push rotation but rigid transform assumes pull
136
+ rotation_matrix = euler_to_rotationmatrix(
137
+ be.to_numpy_array(angles_loop), seq="xyz"
138
+ ).T
183
139
 
184
- volume_temp_rotated, _ = be.rigid_transform(
140
+ volume_temp, _ = be.rigid_transform(
185
141
  arr=volume_temp,
186
- rotation_matrix=rotation_matrix,
187
- out=volume_temp_rotated,
142
+ rotation_matrix=be.to_backend_array(rotation_matrix),
188
143
  use_geometric_center=True,
189
144
  order=interpolation_order,
190
145
  )
191
- volume = be.add(volume, volume_temp_rotated, out=volume)
146
+ rec = be.add(rec, volume_temp, out=rec)
192
147
 
193
- return volume
148
+ # Shift DC component back to origin
149
+ rec = shift_fourier(rec, shape_is_real_fourier=False, ifftshift=True)
150
+ return {"data": rec, "shape": shape, "is_multiplicative_filter": False}
194
151
 
195
152
 
196
153
  class ShiftFourier(ComposableFilter):
197
- def __call__(
198
- self,
199
- data: BackendArray,
200
- shape_is_real_fourier: bool = False,
201
- return_real_fourier: bool = True,
202
- **kwargs,
203
- ):
154
+ def _evaluate(self, shape: Tuple[int, ...], data: BackendArray, **kwargs) -> Dict:
204
155
  ret = []
205
156
  for index in range(data.shape[0]):
206
- mask = be.to_numpy_array(data[index])
207
-
208
- mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier)
209
- if return_real_fourier:
210
- mask = crop_real_fourier(mask)
157
+ mask = shift_fourier(
158
+ data=data[index],
159
+ shape_is_real_fourier=kwargs.get("return_real_fourier", False),
160
+ )
211
161
  ret.append(mask[None])
212
- ret = np.concatenate(ret, axis=0)
213
-
214
- return {
215
- "data": ret,
216
- "shape": kwargs.get("shape"),
217
- "return_real_fourier": return_real_fourier,
218
- "is_multiplicative_filter": False,
219
- }
162
+
163
+ ret = be.concatenate(ret, axis=0)
164
+ return {"data": ret, "shape": shape, "is_multiplicative_filter": False}