pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,160 @@
1
+ """ Defines filters on tomographic tilt series.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple
9
+ from dataclasses import dataclass
10
+
11
+ import numpy as np
12
+
13
+ from ..types import NDArray
14
+ from ..backends import backend as be
15
+
16
+ from .compose import ComposableFilter
17
+ from ..rotations import euler_to_rotationmatrix
18
+ from ._utils import (
19
+ crop_real_fourier,
20
+ shift_fourier,
21
+ create_reconstruction_filter,
22
+ )
23
+
24
+ __all__ = ["ReconstructFromTilt"]
25
+
26
+
27
+ @dataclass
28
+ class ReconstructFromTilt(ComposableFilter):
29
+ """Reconstruct a volume from a tilt series."""
30
+
31
+ #: Shape of the reconstruction.
32
+ shape: Tuple[int] = None
33
+ #: Angle of each individual tilt.
34
+ angles: Tuple[float] = None
35
+ #: The axis around which the volume is opened.
36
+ opening_axis: int = 0
37
+ #: Axis the plane is tilted over.
38
+ tilt_axis: int = 2
39
+ #: Whether to return a share compliant with rfftn.
40
+ return_real_fourier: bool = True
41
+ #: Interpolation order used for rotation
42
+ interpolation_order: int = 1
43
+ #: Filter window applied during reconstruction.
44
+ reconstruction_filter: str = None
45
+
46
+ def __call__(self, **kwargs):
47
+ func_args = vars(self).copy()
48
+ func_args.update(kwargs)
49
+
50
+ ret = self.reconstruct(**func_args)
51
+
52
+ return {
53
+ "data": ret,
54
+ "shape": ret.shape,
55
+ "shape_is_real_fourier": func_args["return_real_fourier"],
56
+ "angles": func_args["angles"],
57
+ "tilt_axis": func_args["tilt_axis"],
58
+ "opening_axis": func_args["opening_axis"],
59
+ "is_multiplicative_filter": False,
60
+ }
61
+
62
+ @staticmethod
63
+ def reconstruct(
64
+ data: NDArray,
65
+ shape: Tuple[int],
66
+ angles: Tuple[float],
67
+ opening_axis: int,
68
+ tilt_axis: int,
69
+ interpolation_order: int = 1,
70
+ return_real_fourier: bool = True,
71
+ reconstruction_filter: str = None,
72
+ **kwargs,
73
+ ):
74
+ """
75
+ Reconstruct a volume from a tilt series.
76
+
77
+ Parameters
78
+ ----------
79
+ data : NDArray
80
+ The tilt series data.
81
+ shape : tuple of int
82
+ Shape of the reconstruction.
83
+ angles : tuple of float
84
+ Angle of each individual tilt.
85
+ opening_axis : int
86
+ The axis around which the volume is opened.
87
+ tilt_axis : int
88
+ Axis the plane is tilted over.
89
+ interpolation_order : int, optional
90
+ Interpolation order used for rotation, defaults to 1.
91
+ return_real_fourier : bool, optional
92
+ Whether to return a shape compliant with rfftn, defaults to True.
93
+ reconstruction_filter : bool, optional
94
+ Filter window applied during reconstruction.
95
+ See :py:meth:`create_reconstruction_filter` for available options.
96
+
97
+ Returns
98
+ -------
99
+ NDArray
100
+ The reconstructed volume.
101
+ """
102
+ if data.shape == shape:
103
+ return data
104
+
105
+ data = be.to_backend_array(data)
106
+ volume_temp = be.zeros(shape, dtype=be._float_dtype)
107
+ volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
108
+ volume = be.zeros(shape, dtype=be._float_dtype)
109
+
110
+ slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape)
111
+ subset = tuple(
112
+ slice(None) if i != opening_axis else x for i, x in enumerate(slices)
113
+ )
114
+ angles_loop = be.zeros(len(shape))
115
+ wedge_dim = [x for x in data.shape]
116
+ wedge_dim.insert(1 + opening_axis, 1)
117
+ wedges = be.reshape(data, wedge_dim)
118
+
119
+ rec_filter = 1
120
+ aspect_ratio = shape[opening_axis] / shape[tilt_axis]
121
+ angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))
122
+ if reconstruction_filter is not None:
123
+ rec_filter = create_reconstruction_filter(
124
+ filter_type=reconstruction_filter,
125
+ filter_shape=(shape[tilt_axis],),
126
+ tilt_angles=angles,
127
+ )
128
+ rec_shape = tuple(1 if i != tilt_axis else x for i, x in enumerate(shape))
129
+ rec_filter = be.to_backend_array(rec_filter)
130
+ rec_filter = be.reshape(rec_filter, rec_shape)
131
+
132
+ angles = be.to_backend_array(angles)
133
+ for index in range(len(angles)):
134
+ angles_loop = be.fill(angles_loop, 0)
135
+ volume_temp = be.fill(volume_temp, 0)
136
+ volume_temp_rotated = be.fill(volume_temp_rotated, 0)
137
+
138
+ # Jax compatibility
139
+ volume_temp = be.at(volume_temp, subset, wedges[index] * rec_filter)
140
+ angles_loop = be.at(angles_loop, tilt_axis, angles[index])
141
+
142
+ angles_loop = be.roll(angles_loop, (opening_axis - 1,), axis=0)
143
+ rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles_loop))
144
+ rotation_matrix = be.to_backend_array(rotation_matrix)
145
+
146
+ volume_temp_rotated, _ = be.rigid_transform(
147
+ arr=volume_temp,
148
+ rotation_matrix=rotation_matrix,
149
+ out=volume_temp_rotated,
150
+ use_geometric_center=True,
151
+ order=interpolation_order,
152
+ )
153
+ volume = be.add(volume, volume_temp_rotated, out=volume)
154
+
155
+ volume = shift_fourier(data=volume, shape_is_real_fourier=False)
156
+
157
+ if return_real_fourier:
158
+ volume = crop_real_fourier(volume)
159
+
160
+ return volume
tme/filters/wedge.py ADDED
@@ -0,0 +1,542 @@
1
+ """ Implements class Wedge and WedgeReconstructed to create Fourier
2
+ filter representations.
3
+
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+
9
+ from typing import Tuple, Dict
10
+
11
+ import numpy as np
12
+
13
+ from ..types import NDArray
14
+ from ..backends import backend as be
15
+ from .compose import ComposableFilter
16
+ from ..matching_utils import centered
17
+ from ..rotations import euler_to_rotationmatrix
18
+ from ._utils import (
19
+ centered_grid,
20
+ frequency_grid_at_angle,
21
+ compute_tilt_shape,
22
+ crop_real_fourier,
23
+ fftfreqn,
24
+ shift_fourier,
25
+ create_reconstruction_filter,
26
+ )
27
+
28
+ __all__ = ["Wedge", "WedgeReconstructed"]
29
+
30
+
31
class Wedge(ComposableFilter):
    """
    Generate wedge mask for tomographic data.

    Parameters
    ----------
    shape : tuple of int
        The shape of the reconstruction volume.
    tilt_axis : int
        Axis the plane is tilted over.
    opening_axis : int
        The axis around which the volume is opened.
    angles : tuple of float
        The tilt angles.
    weights : tuple of float
        The weights corresponding to each tilt angle.
    weight_type : str, optional
        The type of weighting to apply, defaults to None. One of None,
        'angle', 'relion' or 'grigorieff'.
    frequency_cutoff : float, optional
        Frequency cutoff for created mask. Nyquist 0.5 by default.

    Returns
    -------
    Dict
        A dictionary containing weighted wedges and related information.
    """

    def __init__(
        self,
        shape: Tuple[int],
        tilt_axis: int,
        opening_axis: int,
        angles: Tuple[float],
        weights: Tuple[float],
        weight_type: str = None,
        frequency_cutoff: float = 0.5,
    ):
        self.shape = shape
        self.tilt_axis = tilt_axis
        self.opening_axis = opening_axis
        self.angles = angles
        self.weights = weights
        # Stored so __call__ picks it up via vars(self); it was previously dropped.
        self.weight_type = weight_type
        self.frequency_cutoff = frequency_cutoff

    @classmethod
    def from_file(cls, filename: str) -> "Wedge":
        """
        Generate a :py:class:`Wedge` instance by reading tilt angles and weights
        from a tab-separated text file.

        Parameters
        ----------
        filename : str
            The path to the file containing tilt angles and weights.

        Returns
        -------
        :py:class:`Wedge`
            Class instance initialized with angles and weights from the file.

        Raises
        ------
        ValueError
            If no 'angles' column is present or the weights and angles columns
            differ in length.
        """
        data = cls._from_text(filename)

        angles, weights = data.get("angles", None), data.get("weights", None)
        if angles is None:
            raise ValueError(f"Could not find column angles in {filename}")

        if weights is None:
            weights = [1] * len(angles)

        if len(weights) != len(angles):
            raise ValueError("Length of weights and angles differ.")

        return cls(
            shape=None,
            tilt_axis=0,
            opening_axis=2,
            angles=np.array(angles, dtype=np.float32),
            weights=np.array(weights, dtype=np.float32),
        )

    @staticmethod
    def _from_text(filename: str, delimiter="\t") -> Dict:
        """
        Read column data from a text file.

        Parameters
        ----------
        filename : str
            The path to the text file.
        delimiter : str, optional
            The delimiter used in the file, defaults to '\t'.

        Returns
        -------
        Dict
            A dictionary with one key for each column, mapping the header to the
            column's values as strings. Empty if the file has no content.
        """
        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]
        rows = [x.split(delimiter) for x in lines if len(x)]

        if not rows:
            return {}

        headers = rows.pop(0)
        return {header: list(column) for header, column in zip(headers, zip(*rows))}

    def _angles_and_weights(self, kwargs: Dict):
        """Return angles and weights from kwargs, falling back to the instance."""
        angles = kwargs.get("angles", None)
        weights = kwargs.get("weights", None)
        angles = self.angles if angles is None else angles
        weights = self.weights if weights is None else weights
        return angles, weights

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Create the weighted wedges.

        Instance attributes serve as defaults and are overridden by ``kwargs``.

        Returns
        -------
        Dict
            Weighted wedges under ``data`` (one per angle) with ``angles``,
            ``tilt_axis``, ``opening_axis``, ``sampling_rate`` and
            ``is_multiplicative_filter``.

        Raises
        ------
        ValueError
            If ``weight_type`` is not supported.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        weight_types = {
            None: self.weight_angle,
            "angle": self.weight_angle,
            "relion": self.weight_relion,
            "grigorieff": self.weight_grigorieff,
        }

        weight_type = func_args.get("weight_type", None)
        if weight_type not in weight_types:
            # str() is required because one of the keys is None.
            raise ValueError(
                f"Supported weight_types are {','.join(str(x) for x in weight_types)}"
            )

        if weight_type == "angle":
            # Use the effective angles so weights match the wedges generated.
            func_args["weights"] = np.cos(np.radians(func_args["angles"]))

        ret = weight_types[weight_type](**func_args)

        frequency_cutoff = func_args.get("frequency_cutoff", None)
        if frequency_cutoff is not None:
            for index, angle in enumerate(func_args["angles"]):
                frequency_grid = frequency_grid_at_angle(
                    shape=func_args["shape"],
                    opening_axis=func_args["opening_axis"],
                    tilt_axis=func_args["tilt_axis"],
                    angle=angle,
                    sampling_rate=1,
                )
                ret[index] = np.multiply(ret[index], frequency_grid <= frequency_cutoff)

        ret = be.astype(be.to_backend_array(ret), be._float_dtype)

        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def weight_angle(
        shape: Tuple[int],
        weights: Tuple[float],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        **kwargs,
    ) -> NDArray:
        """
        Generate wedges that are constant per tilt, filled with that tilt's
        weight (the cosine of the angle when called with weight_type 'angle').

        Returns
        -------
        NDArray
            Array of shape (len(angles), *tilt_shape).
        """
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        wedges = np.zeros((len(angles), *tilt_shape))
        for index in range(len(angles)):
            wedges[index] = weights[index]

        return wedges

    def weight_relion(
        self, shape: Tuple[int], opening_axis: int, tilt_axis: int, **kwargs
    ) -> NDArray:
        """
        Generate weighted wedges based on the RELION 1.4 formalism, weighting each
        angle using the cosine of the angle and a Gaussian lowpass filter computed
        with respect to the exposure per angstrom.

        Angles and weights (exposures) are taken from ``kwargs`` when given,
        otherwise from the instance.

        Returns
        -------
        NDArray
            Weighted wedges.
        """
        angles, weights = self._angles_and_weights(kwargs)
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )

        wedges = np.zeros((len(angles), *tilt_shape))
        for index, angle in enumerate(angles):
            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            sigma = np.sqrt(weights[index] * 4 / (8 * np.pi**2))
            sigma = -2 * np.pi**2 * sigma**2
            np.square(frequency_grid, out=frequency_grid)
            np.multiply(sigma, frequency_grid, out=frequency_grid)
            np.exp(frequency_grid, out=frequency_grid)
            np.multiply(frequency_grid, np.cos(np.radians(angle)), out=frequency_grid)
            wedges[index] = frequency_grid

        return wedges

    def weight_grigorieff(
        self,
        shape: Tuple[int],
        opening_axis: int,
        tilt_axis: int,
        amplitude: float = 0.245,
        power: float = -1.665,
        offset: float = 2.81,
        **kwargs,
    ) -> NDArray:
        """
        Generate weighted wedges based on the formalism introduced in [1]_.

        Each tilt is attenuated by ``exp(-w / (2 * (amplitude * k**power + offset)))``
        where ``w`` is the tilt's weight and ``k`` the spatial frequency. Angles
        and weights are taken from ``kwargs`` when given, otherwise from the
        instance.

        Returns
        -------
        NDArray
            Weighted wedges.

        References
        ----------
        .. [1] Timothy Grant, Nikolaus Grigorieff (2015), eLife 4:e06980.
        """
        angles, weights = self._angles_and_weights(kwargs)
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )

        # numpy default dtype; the caller casts to the backend float dtype.
        wedges = np.zeros((len(angles), *tilt_shape))
        for index, angle in enumerate(angles):
            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )

            # k**power diverges at k=0; the resulting inf yields exp(-0) = 1.
            with np.errstate(divide="ignore"):
                np.power(frequency_grid, power, out=frequency_grid)
                np.multiply(amplitude, frequency_grid, out=frequency_grid)
                np.add(frequency_grid, offset, out=frequency_grid)
                np.multiply(-2, frequency_grid, out=frequency_grid)
                np.divide(
                    weights[index],
                    frequency_grid,
                    out=frequency_grid,
                )

            wedges[index] = np.exp(frequency_grid)

        return wedges
290
+
291
+
292
class WedgeReconstructed:
    """
    Initialize :py:class:`WedgeReconstructed`.

    Parameters
    ----------
    angles : tuple of float, optional
        The tilt angles, defaults to None.
    opening_axis : int, optional
        The axis around which the wedge is opened.
    tilt_axis : int, optional
        The axis along which the tilt is applied.
    weights : tuple of float, optional
        Weights to assign to individual wedge components.
    weight_wedge : bool, optional
        Whether individual wedge components should be weighted. If True and weights
        is None, uses the cosine of the angle otherwise weights.
    create_continuous_wedge: bool, optional
        Whether to create a continuous wedge or a per-component wedge. Weights are
        only considered for non-continuous wedges.
    frequency_cutoff : float, optional
        Frequency cutoff for created mask. Nyquist 0.5 by default.
    reconstruction_filter : str, optional
        Filter window applied during reconstruction.
    **kwargs : Dict
        Additional keyword arguments.
    """

    def __init__(
        self,
        opening_axis: int,
        tilt_axis: int,
        angles: Tuple[float] = None,
        weights: Tuple[float] = None,
        weight_wedge: bool = False,
        create_continuous_wedge: bool = False,
        frequency_cutoff: float = 0.5,
        reconstruction_filter: str = None,
        **kwargs: Dict,
    ):
        self.angles = angles
        self.opening_axis = opening_axis
        self.tilt_axis = tilt_axis
        self.weights = weights
        self.weight_wedge = weight_wedge
        self.reconstruction_filter = reconstruction_filter
        self.create_continuous_wedge = create_continuous_wedge
        self.frequency_cutoff = frequency_cutoff

    def __call__(self, shape: Tuple[int], **kwargs: Dict) -> Dict:
        """
        Generate the reconstructed wedge.

        Parameters
        ----------
        shape : tuple of int
            The shape of the reconstruction volume.
        **kwargs : Dict
            Additional keyword arguments, overriding instance attributes.

        Returns
        -------
        Dict
            A dictionary containing the reconstructed wedge and related information.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        if kwargs.get("is_fourier_shape", False):
            print("Cannot create continuous wedge mask based on real fourier shape.")

        func = self.step_wedge
        if func_args.get("create_continuous_wedge", False):
            func = self.continuous_wedge

        weight_wedge = func_args.get("weight_wedge", False)
        # Fall back to cosine weighting only when no explicit weights were given.
        if weight_wedge and func_args.get("weights") is None:
            func_args["weights"] = np.cos(
                np.radians(be.to_numpy_array(func_args.get("angles", (0,))))
            )

        ret = func(shape=shape, **func_args)

        frequency_cutoff = func_args.get("frequency_cutoff", None)
        if frequency_cutoff is not None:
            frequency_mask = fftfreqn(
                shape=shape,
                sampling_rate=1,
                compute_euclidean_norm=True,
                shape_is_real_fourier=False,
            )
            ret = np.multiply(ret, frequency_mask <= frequency_cutoff, out=ret)

        # Unweighted wedges are binarised.
        if not weight_wedge:
            ret = (ret > 0) * 1.0

        ret = be.astype(be.to_backend_array(ret), be._float_dtype)

        # Read once so the flag and the returned metadata always agree.
        return_real_fourier = func_args.get("return_real_fourier", False)
        ret = shift_fourier(data=ret, shape_is_real_fourier=False)
        if return_real_fourier:
            ret = crop_real_fourier(ret)

        return {
            "data": ret,
            "shape_is_real_fourier": return_real_fourier,
            "shape": ret.shape,
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
            "angles": func_args["angles"],
        }

    @staticmethod
    def continuous_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a continuous wedge mask with DC component at the center.

        Parameters
        ----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            Start and stop tilt angle.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.

        Returns
        -------
        NDArray
            Wedge mask.
        """
        # Rescale tilt angles by the box aspect ratio of the two in-plane axes.
        aspect_ratio = shape[opening_axis] / shape[tilt_axis]
        angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))

        # Slopes (tilt/opening coordinate ratios) bounding the sampled region.
        start_radians = np.tan(np.radians(90 - angles[0]))
        stop_radians = np.tan(np.radians(-1 * (90 - angles[1])))

        grid = centered_grid(shape)
        with np.errstate(divide="ignore", invalid="ignore"):
            # Where the opening coordinate is zero use a value above any slope.
            ratios = np.where(
                grid[opening_axis] == 0,
                np.tan(np.radians(90)) + 1,
                grid[tilt_axis] / grid[opening_axis],
            )

        wedge = np.logical_or(start_radians <= ratios, stop_radians >= ratios).astype(
            np.float32
        )

        return wedge

    @staticmethod
    def step_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        weights: Tuple[float] = None,
        reconstruction_filter: str = None,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a per-angle wedge shape with DC component at the center.

        Parameters
        ----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            The tilt angles.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.
        reconstruction_filter : str
            Filter used during reconstruction.
        weights : tuple of float, optional
            Weights to assign to individual tilts. Defaults to 1.

        Returns
        -------
        NDArray
            Wedge mask.
        """
        from ..backends import NumpyFFTWBackend

        angles = np.asarray(be.to_numpy_array(angles))

        if weights is None:
            weights = np.ones(angles.size)
        weights = np.asarray(weights)

        shape = tuple(int(x) for x in shape)
        opening_axis, tilt_axis = int(opening_axis), int(tilt_axis)

        weights = np.repeat(weights, angles.size // weights.size)
        # 2D (opening, tilt) plane; the tilt extent is padded to be odd.
        plane = np.zeros(
            (shape[opening_axis], shape[tilt_axis] + (1 - shape[tilt_axis] % 2)),
            dtype=np.float32,
        )

        aspect_ratio = plane.shape[0] / plane.shape[1]
        angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))

        rec_filter = 1
        if reconstruction_filter is not None:
            rec_filter = create_reconstruction_filter(
                plane.shape[::-1], filter_type=reconstruction_filter, tilt_angles=angles
            ).T

        # Central line of the plane; it is never modified below, so set it once.
        subset = tuple(
            slice(None) if i != 0 else slice(x // 2, x // 2 + 1)
            for i, x in enumerate(plane.shape)
        )
        plane[subset] = 1
        filtered_plane = plane * rec_filter

        plane_rotated, wedge_volume = np.zeros_like(plane), np.zeros_like(plane)
        for index in range(angles.shape[0]):
            plane_rotated.fill(0)
            rotation_matrix = euler_to_rotationmatrix((angles[index], 0))
            rotation_matrix = rotation_matrix[np.ix_((0, 1), (0, 1))]

            NumpyFFTWBackend().rigid_transform(
                arr=filtered_plane,
                rotation_matrix=rotation_matrix,
                out=plane_rotated,
                use_geometric_center=True,
                order=1,
            )
            wedge_volume += plane_rotated * weights[index]

        wedge_volume = centered(wedge_volume, (shape[opening_axis], shape[tilt_axis]))
        # Overlapping lines near DC must not exceed the largest weight.
        np.fmin(wedge_volume, np.max(weights), wedge_volume)

        if opening_axis > tilt_axis:
            wedge_volume = np.moveaxis(wedge_volume, 1, 0)

        # Broadcast the 2D plane along all remaining axes of `shape`.
        reshape_dimensions = tuple(
            x if i in (opening_axis, tilt_axis) else 1 for i, x in enumerate(shape)
        )

        wedge_volume = wedge_volume.reshape(reshape_dimensions)
        tile_dimensions = np.divide(shape, reshape_dimensions).astype(int)
        wedge_volume = np.tile(wedge_volume, tile_dimensions)

        return wedge_volume