pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75):
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -53
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +396 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -201
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +158 -28
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.dist-info/RECORD +0 -119
  70. pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,9 @@
1
- """ Defines filters on tomographic tilt series.
1
+ """
2
+ Defines filters on tomographic tilt series.
2
3
 
3
- Copyright (c) 2024 European Molecular Biology Laboratory
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from typing import Tuple
@@ -10,32 +11,31 @@ from dataclasses import dataclass
10
11
 
11
12
  import numpy as np
12
13
 
13
- from ..types import NDArray
14
14
  from ..backends import backend as be
15
+ from ..types import NDArray, BackendArray
15
16
 
16
17
  from .compose import ComposableFilter
17
18
  from ..rotations import euler_to_rotationmatrix
18
- from ._utils import (
19
- crop_real_fourier,
20
- shift_fourier,
21
- create_reconstruction_filter,
22
- )
19
+ from ._utils import crop_real_fourier, shift_fourier, create_reconstruction_filter
23
20
 
24
- __all__ = ["ReconstructFromTilt"]
21
+ __all__ = ["ReconstructFromTilt", "ShiftFourier"]
25
22
 
26
23
 
27
24
  @dataclass
28
25
  class ReconstructFromTilt(ComposableFilter):
29
- """Reconstruct a volume from a tilt series."""
26
+ """
27
+ Reconstruct a d+1 array from a d-dimensional input projection using weighted
28
+ backprojection (WBP).
29
+ """
30
30
 
31
31
  #: Shape of the reconstruction.
32
32
  shape: Tuple[int] = None
33
33
  #: Angle of each individual tilt.
34
34
  angles: Tuple[float] = None
35
- #: The axis around which the volume is opened.
36
- opening_axis: int = 0
37
- #: Axis the plane is tilted over.
38
- tilt_axis: int = 2
35
+ #: Projection axis, defaults to 2 (z).
36
+ opening_axis: int = 2
37
+ #: Tilt axis, defaults to 0 (x).
38
+ tilt_axis: int = 0
39
39
  #: Whether to return a share compliant with rfftn.
40
40
  return_real_fourier: bool = True
41
41
  #: Interpolation order used for rotation
@@ -43,19 +43,60 @@ class ReconstructFromTilt(ComposableFilter):
43
43
  #: Filter window applied during reconstruction.
44
44
  reconstruction_filter: str = None
45
45
 
46
- def __call__(self, **kwargs):
46
+ def __call__(self, return_real_fourier: bool = False, **kwargs):
47
+ """
48
+ Reconstruct a d+1 array from a d-dimensional input using WBP.
49
+
50
+ Parameters
51
+ ----------
52
+ shape : tuple of int
53
+ The shape of the reconstruction volume.
54
+ data : BackendArray
55
+ D-dimensional image stack with shape (n, ...). The data is assumed to be
56
+ a Fourier transform of the stack you are trying to reconstruct with
57
+ DC component in the center.
58
+ angles : tuple of float
59
+ Angle of each individual tilt.
60
+ return_real_fourier : bool, optional
61
+ Return a shape compliant
62
+ return_real_fourier : tuple of int
63
+ Return a shape compliant with rfft, i.e., omit the negative frequencies
64
+ terms resulting in a return shape (*shape[:-1], shape[-1]//2+1). Defaults
65
+ to False.
66
+ reconstruction_filter : bool, optional
67
+ Filter window applied during reconstruction.
68
+ See :py:meth:`create_reconstruction_filter` for available options.
69
+ tilt_axis : int
70
+ Axis the plane is tilted over, defaults to 0 (x).
71
+ opening_axis : int
72
+ The projection axis, defaults to 2 (z).
73
+
74
+ Returns
75
+ -------
76
+ dict
77
+ data: BackendArray
78
+ The filter mask.
79
+ shape: tuple of ints
80
+ The requested filter shape
81
+ return_real_fourier: bool
82
+ Whether data is compliant with rfftn.
83
+ is_multiplicative_filter: bool
84
+ Whether the filter is multiplicative in Fourier space.
85
+ """
86
+
47
87
  func_args = vars(self).copy()
48
88
  func_args.update(kwargs)
49
89
 
50
90
  ret = self.reconstruct(**func_args)
51
91
 
92
+ ret = shift_fourier(data=ret, shape_is_real_fourier=False)
93
+ if return_real_fourier:
94
+ ret = crop_real_fourier(ret)
95
+
52
96
  return {
53
97
  "data": ret,
54
- "shape": ret.shape,
55
- "shape_is_real_fourier": func_args["return_real_fourier"],
56
- "angles": func_args["angles"],
57
- "tilt_axis": func_args["tilt_axis"],
58
- "opening_axis": func_args["opening_axis"],
98
+ "shape": func_args["shape"],
99
+ "return_real_fourier": return_real_fourier,
59
100
  "is_multiplicative_filter": False,
60
101
  }
61
102
 
@@ -67,7 +108,6 @@ class ReconstructFromTilt(ComposableFilter):
67
108
  opening_axis: int,
68
109
  tilt_axis: int,
69
110
  interpolation_order: int = 1,
70
- return_real_fourier: bool = True,
71
111
  reconstruction_filter: str = None,
72
112
  **kwargs,
73
113
  ):
@@ -77,7 +117,7 @@ class ReconstructFromTilt(ComposableFilter):
77
117
  Parameters
78
118
  ----------
79
119
  data : NDArray
80
- The tilt series data.
120
+ The Fourier transform of tilt series data.
81
121
  shape : tuple of int
82
122
  Shape of the reconstruction.
83
123
  angles : tuple of float
@@ -88,8 +128,6 @@ class ReconstructFromTilt(ComposableFilter):
88
128
  Axis the plane is tilted over.
89
129
  interpolation_order : int, optional
90
130
  Interpolation order used for rotation, defaults to 1.
91
- return_real_fourier : bool, optional
92
- Whether to return a shape compliant with rfftn, defaults to True.
93
131
  reconstruction_filter : bool, optional
94
132
  Filter window applied during reconstruction.
95
133
  See :py:meth:`create_reconstruction_filter` for available options.
@@ -103,9 +141,9 @@ class ReconstructFromTilt(ComposableFilter):
103
141
  return data
104
142
 
105
143
  data = be.to_backend_array(data)
106
- volume_temp = be.zeros(shape, dtype=be._float_dtype)
107
- volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
108
- volume = be.zeros(shape, dtype=be._float_dtype)
144
+ volume_temp = be.zeros(shape, dtype=data.dtype)
145
+ volume_temp_rotated = be.zeros(shape, dtype=data.dtype)
146
+ volume = be.zeros(shape, dtype=data.dtype)
109
147
 
110
148
  slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape)
111
149
  subset = tuple(
@@ -152,9 +190,30 @@ class ReconstructFromTilt(ComposableFilter):
152
190
  )
153
191
  volume = be.add(volume, volume_temp_rotated, out=volume)
154
192
 
155
- volume = shift_fourier(data=volume, shape_is_real_fourier=False)
193
+ return volume
156
194
 
157
- if return_real_fourier:
158
- volume = crop_real_fourier(volume)
159
195
 
160
- return volume
196
+ class ShiftFourier(ComposableFilter):
197
+ def __call__(
198
+ self,
199
+ data: BackendArray,
200
+ shape_is_real_fourier: bool = False,
201
+ return_real_fourier: bool = True,
202
+ **kwargs,
203
+ ):
204
+ ret = []
205
+ for index in range(data.shape[0]):
206
+ mask = be.to_numpy_array(data[index])
207
+
208
+ mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier)
209
+ if return_real_fourier:
210
+ mask = crop_real_fourier(mask)
211
+ ret.append(mask[None])
212
+ ret = np.concatenate(ret, axis=0)
213
+
214
+ return {
215
+ "data": ret,
216
+ "shape": kwargs.get("shape"),
217
+ "return_real_fourier": return_real_fourier,
218
+ "is_multiplicative_filter": False,
219
+ }
tme/filters/wedge.py CHANGED
@@ -1,12 +1,14 @@
1
- """ Implements class Wedge and WedgeReconstructed to create Fourier
2
- filter representations.
1
+ """
2
+ Implements class Wedge and WedgeReconstructed to create Fourier
3
+ filter representations.
3
4
 
4
- Copyright (c) 2024 European Molecular Biology Laboratory
5
+ Copyright (c) 2024 European Molecular Biology Laboratory
5
6
 
6
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
8
  """
8
9
 
9
10
  from typing import Tuple, Dict
11
+ from dataclasses import dataclass
10
12
 
11
13
  import numpy as np
12
14
 
@@ -15,6 +17,7 @@ from ..backends import backend as be
15
17
  from .compose import ComposableFilter
16
18
  from ..matching_utils import centered
17
19
  from ..rotations import euler_to_rotationmatrix
20
+ from ..parser import XMLParser, StarParser, MDOCParser
18
21
  from ._utils import (
19
22
  centered_grid,
20
23
  frequency_grid_at_angle,
@@ -28,55 +31,44 @@ from ._utils import (
28
31
  __all__ = ["Wedge", "WedgeReconstructed"]
29
32
 
30
33
 
34
+ @dataclass
31
35
  class Wedge(ComposableFilter):
32
36
  """
33
37
  Generate wedge mask for tomographic data.
34
-
35
- Parameters
36
- ----------
37
- shape : tuple of int
38
- The shape of the reconstruction volume.
39
- tilt_axis : int
40
- Axis the plane is tilted over.
41
- opening_axis : int
42
- The axis around which the volume is opened.
43
- angles : tuple of float
44
- The tilt angles.
45
- weights : tuple of float
46
- The weights corresponding to each tilt angle.
47
- weight_type : str, optional
48
- The type of weighting to apply, defaults to None.
49
- frequency_cutoff : float, optional
50
- Frequency cutoff for created mask. Nyquist 0.5 by default.
51
-
52
- Returns
53
- -------
54
- Dict
55
- A dictionary containing weighted wedges and related information.
56
38
  """
57
39
 
58
- def __init__(
59
- self,
60
- shape: Tuple[int],
61
- tilt_axis: int,
62
- opening_axis: int,
63
- angles: Tuple[float],
64
- weights: Tuple[float],
65
- weight_type: str = None,
66
- frequency_cutoff: float = 0.5,
67
- ):
68
- self.shape = shape
69
- self.tilt_axis = tilt_axis
70
- self.opening_axis = opening_axis
71
- self.angles = angles
72
- self.weights = weights
73
- self.frequency_cutoff = frequency_cutoff
40
+ #: The shape of the reconstruction volume.
41
+ shape: Tuple[int] = None
42
+ #: The tilt angles.
43
+ angles: Tuple[float] = None
44
+ #: The weights corresponding to each tilt angle.
45
+ weights: Tuple[float] = None
46
+ #: Axis the plane is tilted over, defaults to 0 (x).
47
+ tilt_axis: int = 0
48
+ #: The projection axis, defaults to 2 (z).
49
+ opening_axis: int = 2
50
+ #: The type of weighting to apply, defaults to None.
51
+ weight_type: str = None
52
+ #: Frequency cutoff for created mask. Nyquist 0.5 by default.
53
+ frequency_cutoff: float = 0.5
54
+ #: The sampling rate, defaults to 1 Ångstrom / voxel.
55
+ sampling_rate: Tuple[float] = 1
74
56
 
75
57
  @classmethod
76
58
  def from_file(cls, filename: str) -> "Wedge":
77
59
  """
78
- Generate a :py:class:`Wedge` instance by reading tilt angles and weights
79
- from a tab-separated text file.
60
+ Generate a :py:class:`Wedge` instance by reading tilt angles and weights.
61
+ Supported extensions are:
62
+
63
+ +-------+---------------------------------------------------------+
64
+ | .star | Tomostar STAR file |
65
+ +-------+---------------------------------------------------------+
66
+ | .xml | WARP/M XML file |
67
+ +-------+---------------------------------------------------------+
68
+ | .mdoc | SerialEM file |
69
+ +-------+---------------------------------------------------------+
70
+ | .* | Tab-separated file with optional column names |
71
+ +-------+---------------------------------------------------------+
80
72
 
81
73
  Parameters
82
74
  ----------
@@ -88,8 +80,15 @@ class Wedge(ComposableFilter):
88
80
  :py:class:`Wedge`
89
81
  Class instance instance initialized with angles and weights from the file.
90
82
  """
91
- data = cls._from_text(filename)
92
-
83
+ func = _from_text
84
+ if filename.lower().endswith("xml"):
85
+ func = _from_xml
86
+ elif filename.lower().endswith("star"):
87
+ func = _from_star
88
+ elif filename.lower().endswith("mdoc"):
89
+ func = _from_mdoc
90
+
91
+ data = func(filename)
93
92
  angles, weights = data.get("angles", None), data.get("weights", None)
94
93
  if angles is None:
95
94
  raise ValueError(f"Could not find colum angles in {filename}")
@@ -108,33 +107,10 @@ class Wedge(ComposableFilter):
108
107
  weights=np.array(weights, dtype=np.float32),
109
108
  )
110
109
 
111
- @staticmethod
112
- def _from_text(filename: str, delimiter="\t") -> Dict:
110
+ def __call__(self, **kwargs: Dict) -> NDArray:
113
111
  """
114
- Read column data from a text file.
115
-
116
- Parameters
117
- ----------
118
- filename : str
119
- The path to the text file.
120
- delimiter : str, optional
121
- The delimiter used in the file, defaults to '\t'.
122
-
123
- Returns
124
- -------
125
- Dict
126
- A dictionary with one key for each column.
112
+ Returns a Wedge stack of chosen parameters with DC component in the center.
127
113
  """
128
- with open(filename, mode="r", encoding="utf-8") as infile:
129
- data = [x.strip() for x in infile.read().split("\n")]
130
- data = [x.split("\t") for x in data if len(x)]
131
-
132
- headers = data.pop(0)
133
- ret = {header: list(column) for header, column in zip(headers, zip(*data))}
134
-
135
- return ret
136
-
137
- def __call__(self, **kwargs: Dict) -> NDArray:
138
114
  func_args = vars(self).copy()
139
115
  func_args.update(kwargs)
140
116
 
@@ -172,10 +148,8 @@ class Wedge(ComposableFilter):
172
148
 
173
149
  return {
174
150
  "data": ret,
175
- "angles": func_args["angles"],
176
- "tilt_axis": func_args["tilt_axis"],
177
- "opening_axis": func_args["opening_axis"],
178
- "sampling_rate": func_args.get("sampling_rate", 1),
151
+ "shape": func_args["shape"],
152
+ "return_real_fourier": func_args.get("return_real_fourier", False),
179
153
  "is_multiplicative_filter": True,
180
154
  }
181
155
 
@@ -202,7 +176,12 @@ class Wedge(ComposableFilter):
202
176
  return wedges
203
177
 
204
178
  def weight_relion(
205
- self, shape: Tuple[int], opening_axis: int, tilt_axis: int, **kwargs
179
+ self,
180
+ shape: Tuple[int],
181
+ opening_axis: int,
182
+ tilt_axis: int,
183
+ sampling_rate: float = 1.0,
184
+ **kwargs,
206
185
  ) -> NDArray:
207
186
  """
208
187
  Generate weighted wedges based on the RELION 1.4 formalism, weighting each
@@ -217,7 +196,6 @@ class Wedge(ComposableFilter):
217
196
  tilt_shape = compute_tilt_shape(
218
197
  shape=shape, opening_axis=opening_axis, reduce_dim=True
219
198
  )
220
-
221
199
  wedges = np.zeros((len(self.angles), *tilt_shape))
222
200
  for index, angle in enumerate(self.angles):
223
201
  frequency_grid = frequency_grid_at_angle(
@@ -225,7 +203,7 @@ class Wedge(ComposableFilter):
225
203
  opening_axis=opening_axis,
226
204
  tilt_axis=tilt_axis,
227
205
  angle=angle,
228
- sampling_rate=1,
206
+ sampling_rate=sampling_rate,
229
207
  )
230
208
  sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
231
209
  sigma = -2 * np.pi**2 * sigma**2
@@ -245,6 +223,7 @@ class Wedge(ComposableFilter):
245
223
  amplitude: float = 0.245,
246
224
  power: float = -1.665,
247
225
  offset: float = 2.81,
226
+ sampling_rate: float = 1.0,
248
227
  **kwargs,
249
228
  ) -> NDArray:
250
229
  """
@@ -270,7 +249,7 @@ class Wedge(ComposableFilter):
270
249
  opening_axis=opening_axis,
271
250
  tilt_axis=tilt_axis,
272
251
  angle=angle,
273
- sampling_rate=1,
252
+ sampling_rate=sampling_rate,
274
253
  )
275
254
 
276
255
  with np.errstate(divide="ignore"):
@@ -289,54 +268,36 @@ class Wedge(ComposableFilter):
289
268
  return wedges
290
269
 
291
270
 
271
+ @dataclass
292
272
  class WedgeReconstructed:
293
273
  """
294
274
  Initialize :py:class:`WedgeReconstructed`.
295
-
296
- Parameters
297
- ----------
298
- angles :tuple of float, optional
299
- The tilt angles, defaults to None.
300
- opening_axis : int, optional
301
- The axis around which the wedge is opened.
302
- tilt_axis : int, optional
303
- The axis along which the tilt is applied.
304
- weights : tuple of float, optional
305
- Weights to assign to individual wedge components.
306
- weight_wedge : bool, optional
307
- Whether individual wedge components should be weighted. If True and weights
308
- is None, uses the cosine of the angle otherwise weights.
309
- create_continuous_wedge: bool, optional
310
- Whether to create a continous wedge or a per-component wedge. Weights are only
311
- considered for non-continuous wedges.
312
- frequency_cutoff : float, optional
313
- Filter window applied during reconstruction.
314
- **kwargs : Dict
315
- Additional keyword arguments.
316
275
  """
317
276
 
318
- def __init__(
319
- self,
320
- opening_axis: int,
321
- tilt_axis: int,
322
- angles: Tuple[float] = None,
323
- weights: Tuple[float] = None,
324
- weight_wedge: bool = False,
325
- create_continuous_wedge: bool = False,
326
- frequency_cutoff: float = 0.5,
327
- reconstruction_filter: str = None,
328
- **kwargs: Dict,
329
- ):
330
- self.angles = angles
331
- self.opening_axis = opening_axis
332
- self.tilt_axis = tilt_axis
333
- self.weights = weights
334
- self.weight_wedge = weight_wedge
335
- self.reconstruction_filter = reconstruction_filter
336
- self.create_continuous_wedge = create_continuous_wedge
337
- self.frequency_cutoff = frequency_cutoff
338
-
339
- def __call__(self, shape: Tuple[int], **kwargs: Dict) -> Dict:
277
+ #: The tilt angles, defaults to None.
278
+ angles: Tuple[float] = None
279
+ #: Weights to assign to individual wedge components. Not considered for continuous wedge
280
+ weights: Tuple[float] = None
281
+ #: Whether individual wedge components should be weighted.
282
+ weight_wedge: bool = False
283
+ #: Whether to create a continous wedge or a per-component wedge.
284
+ create_continuous_wedge: bool = False
285
+ #: Frequency cutoff of filter
286
+ frequency_cutoff: float = 0.5
287
+ #: Axis the plane is tilted over, defaults to 0 (x).
288
+ tilt_axis: int = 0
289
+ #: The projection axis, defaults to 2 (z).
290
+ opening_axis: int = 2
291
+ #: Filter window applied during reconstruction.
292
+ reconstruction_filter: str = None
293
+
294
+ def __post_init__(self):
295
+ if self.create_continuous_wedge:
296
+ self.angles = (min(self.angles), max(self.angles))
297
+
298
+ def __call__(
299
+ self, shape: Tuple[int], return_real_fourier: bool = False, **kwargs
300
+ ) -> Dict:
340
301
  """
341
302
  Generate the reconstructed wedge.
342
303
 
@@ -344,20 +305,26 @@ class WedgeReconstructed:
344
305
  ----------
345
306
  shape : tuple of int
346
307
  The shape of the reconstruction volume.
347
- **kwargs : Dict
308
+ return_real_fourier : tuple of int
309
+ Return a shape compliant with rfftn. Defaults to False.
310
+ **kwargs : dict
348
311
  Additional keyword arguments.
349
312
 
350
313
  Returns
351
314
  -------
352
- Dict
353
- A dictionary containing the reconstructed wedge and related information.
315
+ dict
316
+ data: BackendArray
317
+ The filter mask.
318
+ shape: tuple of ints
319
+ The requested filter shape
320
+ return_real_fourier: bool
321
+ Whether data is compliant with rfftn.
322
+ is_multiplicative_filter: bool
323
+ Whether the filter is multiplicative in Fourier space.
354
324
  """
355
325
  func_args = vars(self).copy()
356
326
  func_args.update(kwargs)
357
327
 
358
- if kwargs.get("is_fourier_shape", False):
359
- print("Cannot create continuous wedge mask based on real fourier shape.")
360
-
361
328
  func = self.step_wedge
362
329
  if func_args.get("create_continuous_wedge", False):
363
330
  func = self.continuous_wedge
@@ -369,7 +336,6 @@ class WedgeReconstructed:
369
336
  )
370
337
 
371
338
  ret = func(shape=shape, **func_args)
372
-
373
339
  frequency_cutoff = func_args.get("frequency_cutoff", None)
374
340
  if frequency_cutoff is not None:
375
341
  frequency_mask = fftfreqn(
@@ -386,17 +352,15 @@ class WedgeReconstructed:
386
352
  ret = be.astype(be.to_backend_array(ret), be._float_dtype)
387
353
 
388
354
  ret = shift_fourier(data=ret, shape_is_real_fourier=False)
389
- if func_args.get("return_real_fourier", False):
355
+
356
+ if return_real_fourier:
390
357
  ret = crop_real_fourier(ret)
391
358
 
392
359
  return {
393
360
  "data": ret,
394
- "shape_is_real_fourier": func_args["return_real_fourier"],
395
- "shape": ret.shape,
396
- "tilt_axis": func_args["tilt_axis"],
397
- "opening_axis": func_args["opening_axis"],
361
+ "shape": shape,
362
+ "return_real_fourier": return_real_fourier,
398
363
  "is_multiplicative_filter": True,
399
- "angles": func_args["angles"],
400
364
  }
401
365
 
402
366
  @staticmethod
@@ -426,6 +390,7 @@ class WedgeReconstructed:
426
390
  NDArray
427
391
  Wedge mask.
428
392
  """
393
+ angles = np.abs(np.asarray(angles))
429
394
  aspect_ratio = shape[opening_axis] / shape[tilt_axis]
430
395
  angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))
431
396
 
@@ -540,3 +505,99 @@ class WedgeReconstructed:
540
505
  wedge_volume = np.tile(wedge_volume, tile_dimensions)
541
506
 
542
507
  return wedge_volume
508
+
509
+
510
+ def _from_xml(filename: str, **kwargs) -> Dict:
511
+ """
512
+ Read tilt data from a WARP/M XML file.
513
+
514
+ Parameters
515
+ ----------
516
+ filename : str
517
+ The path to the text file.
518
+
519
+ Returns
520
+ -------
521
+ Dict
522
+ A dictionary with one key for each column.
523
+ """
524
+ data = XMLParser(filename)
525
+ return {"angles": data["Angles"], "weights": data["Dose"]}
526
+
527
+
528
+ def _from_star(filename: str, **kwargs) -> Dict:
529
+ """
530
+ Read tilt data from a STAR file.
531
+
532
+ Parameters
533
+ ----------
534
+ filename : str
535
+ The path to the text file.
536
+
537
+ Returns
538
+ -------
539
+ Dict
540
+ A dictionary with one key for each column.
541
+ """
542
+ data = StarParser(filename, delimiter=None)
543
+ if "data_stopgap_wedgelist" in data:
544
+ angles = data["data_stopgap_wedgelist"]["_tilt_angle"]
545
+ weights = data["data_stopgap_wedgelist"]["_exposure"]
546
+ else:
547
+ angles = data["data_"]["_wrpAxisAngle"]
548
+ weights = data["data_"]["_wrpDose"]
549
+ return {"angles": angles, "weights": weights}
550
+
551
+
552
+ def _from_mdoc(filename: str, **kwargs) -> Dict:
553
+ """
554
+ Read tilt data from a SerialEM MDOC file.
555
+
556
+ Parameters
557
+ ----------
558
+ filename : str
559
+ The path to the text file.
560
+
561
+ Returns
562
+ -------
563
+ Dict
564
+ A dictionary with one key for each column.
565
+ """
566
+ data = MDOCParser(filename)
567
+ cumulative_exposure = np.multiply(np.add(1, data["ZValue"]), data["ExposureDose"])
568
+ return {"angles": data["TiltAngle"], "weights": cumulative_exposure}
569
+
570
+
571
+ def _from_text(filename: str, **kwargs) -> Dict:
572
+ """
573
+ Read column data from a text file.
574
+
575
+ Parameters
576
+ ----------
577
+ filename : str
578
+ The path to the text file.
579
+ delimiter : str, optional
580
+ The delimiter used in the file, defaults to '\t'.
581
+
582
+ Returns
583
+ -------
584
+ Dict
585
+ A dictionary with one key for each column.
586
+ """
587
+ with open(filename, mode="r", encoding="utf-8") as infile:
588
+ data = [x.strip() for x in infile.read().split("\n")]
589
+ data = [x.split("\t") for x in data if len(x)]
590
+
591
+ if "angles" in data[0]:
592
+ headers = data.pop(0)
593
+ else:
594
+ if len(data[0]) != 1:
595
+ raise ValueError(
596
+ "Found more than one column without column names. Please add "
597
+ "column names to your file. If you only want to specify tilt "
598
+ "angles without column names, use a single column file."
599
+ )
600
+ headers = ("angles",)
601
+ ret = {header: list(column) for header, column in zip(headers, zip(*data))}
602
+
603
+ return ret