pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post1.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/density.py CHANGED
@@ -16,7 +16,6 @@ from os.path import splitext, basename
16
16
  import h5py
17
17
  import mrcfile
18
18
  import numpy as np
19
- import skimage.io as skio
20
19
 
21
20
  from scipy.ndimage import (
22
21
  zoom,
@@ -26,7 +25,6 @@ from scipy.ndimage import (
26
25
  binary_erosion,
27
26
  generic_gradient_magnitude,
28
27
  )
29
- from scipy.spatial import ConvexHull
30
28
 
31
29
  from .types import NDArray
32
30
  from .rotations import align_to_axis
@@ -571,6 +569,8 @@ class Density:
571
569
  --------
572
570
  :py:meth:`Density.from_file`
573
571
  """
572
+ import skimage.io as skio
573
+
574
574
  swap = filename
575
575
  if is_gzipped(filename):
576
576
  with gzip_open(filename, "rb") as infile:
@@ -938,6 +938,8 @@ class Density:
938
938
  ----------
939
939
  .. [1] https://scikit-image.org/docs/stable/api/skimage.io.html
940
940
  """
941
+ import skimage.io as skio
942
+
941
943
  swap, kwargs = filename, {}
942
944
  if gzip:
943
945
  swap = BytesIO()
@@ -1403,8 +1405,7 @@ class Density:
1403
1405
  cutoff : float
1404
1406
  Above this value arr elements are considered. Defaults to 0.
1405
1407
  use_geometric_center : bool, optional
1406
- Whether the box should accommodate the geometric or the coordinate
1407
- center. Defaults to False.
1408
+ Accommodate the geometric instead of the mass center.
1408
1409
 
1409
1410
  Returns
1410
1411
  -------
@@ -1416,24 +1417,25 @@ class Density:
1416
1417
  :py:meth:`Density.adjust_box`
1417
1418
  :py:meth:`tme.matching_utils.minimum_enclosing_box`
1418
1419
  """
1419
- coordinates = self.to_pointcloud(threshold=cutoff)
1420
- starts, stops = coordinates.min(axis=1), coordinates.max(axis=1)
1420
+ if cutoff is None:
1421
+ cutoff = self.data.min() - 1
1421
1422
 
1423
+ coordinates = self.to_pointcloud(threshold=cutoff)
1422
1424
  shape = minimum_enclosing_box(
1423
1425
  coordinates=coordinates,
1424
1426
  use_geometric_center=use_geometric_center,
1425
1427
  )
1426
- difference = np.maximum(np.subtract(shape, np.subtract(stops, starts)), 0)
1427
1428
 
1428
- shift_start = np.divide(difference, 2).astype(int)
1429
- shift_stop = shift_start + np.mod(difference, 2)
1429
+ starts, stops = coordinates.min(axis=1), coordinates.max(axis=1)
1430
+ diff = np.maximum(np.subtract(shape, np.subtract(stops, starts)), 0)
1430
1431
 
1432
+ shift_start = np.divide(diff, 2).astype(int)
1433
+ shift_stop = shift_start + np.mod(diff, 2)
1434
+
1435
+ # Starts may purposefully become negative and stops exceed the array bounds to indicate left and right pad
1431
1436
  starts = (starts - shift_start).astype(int)
1432
1437
  stops = (stops + shift_stop).astype(int)
1433
-
1434
- enclosing_box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
1435
-
1436
- return tuple(enclosing_box)
1438
+ return tuple(slice(start, stop) for start, stop in zip(starts, stops))
1437
1439
 
1438
1440
  def pad(
1439
1441
  self, new_shape: Tuple[int], center: bool = True, padding_value: float = 0
@@ -1500,7 +1502,9 @@ class Density:
1500
1502
 
1501
1503
  self.adjust_box(new_box, pad_kwargs={"constant_values": padding_value})
1502
1504
 
1503
- def centered(self, cutoff: float = 0) -> Tuple["Density", NDArray]:
1505
+ def centered(
1506
+ self, cutoff: float = 0.0, order: int = 1
1507
+ ) -> Tuple["Density", NDArray]:
1504
1508
  """
1505
1509
  Shifts the data center of mass to the center of the data array using linear
1506
1510
  interpolation. The box size of the returned :py:class:`Density` object is at
@@ -1509,8 +1513,10 @@ class Density:
1509
1513
  Parameters
1510
1514
  ----------
1511
1515
  cutoff : float, optional
1512
- Only elements in data larger than cutoff will be considered for
1513
- computing the new box. By default considers only positive elements.
1516
+ Consider all elements larger than cutoff for computing the new box.
1517
+ Default is 0.0.
1518
+ order : int, optional
1519
+ Interpolation order, defaults to 1.
1514
1520
 
1515
1521
  Notes
1516
1522
  -----
@@ -1534,16 +1540,15 @@ class Density:
1534
1540
  Examples
1535
1541
  --------
1536
1542
  :py:meth:`Density.centered` returns a tuple containing a centered version
1537
- of the current :py:class:`Density` instance, as well as an array with
1538
- translations. The translation corresponds to the shift between the original and
1539
- current center of mass.
1543
+ of the current :py:class:`Density` instance. Centering is achieved via padding
1544
+ and rigid-transform of the internal :py:attr:`Density.data` attribute.
1545
+ `centered_dens` is sufficiently large to represent all rotations of the
1546
+ :py:attr:`Density.data` attribute.
1540
1547
 
1541
1548
  >>> import numpy as np
1542
1549
  >>> from tme import Density
1543
1550
  >>> dens = Density(np.ones((5,5,5)))
1544
- >>> centered_dens, translation = dens.centered(0)
1545
- >>> translation
1546
- array([0., 0., 0.])
1551
+ >>> centered_dens = dens.centered(0)
1547
1552
 
1548
1553
  :py:meth:`Density.centered` extended the :py:attr:`Density.data` attribute
1549
1554
  of the current :py:class:`Density` instance and modified
@@ -1551,22 +1556,6 @@ class Density:
1551
1556
 
1552
1557
  >>> centered_dens
1553
1558
  Origin: (-2.0, -2.0, -2.0), sampling_rate: (1, 1, 1), Shape: (9, 9, 9)
1554
-
1555
- :py:meth:`Density.centered` achieves centering via zero-padding and
1556
- rigid-transform of the internal :py:attr:`Density.data` attribute.
1557
- `centered_dens` is sufficiently large to represent all rotations of the
1558
- :py:attr:`Density.data` attribute, such as ones obtained from
1559
- :py:meth:`tme.matching_utils.get_rotation_matrices`.
1560
-
1561
- >>> from tme.matching_utils import get_rotation_matrices
1562
- >>> rotation_matrix = get_rotation_matrices(dim = 3 ,angular_sampling = 10)[0]
1563
- >>> rotated_centered_dens = centered_dens.rigid_transform(
1564
- >>> rotation_matrix = rotation_matrix,
1565
- >>> order = None
1566
- >>> )
1567
- >>> print(centered_dens.data.sum(), rotated_centered_dens.data.sum())
1568
- 125.0 125.0
1569
-
1570
1559
  """
1571
1560
  ret = self.copy()
1572
1561
 
@@ -1583,12 +1572,11 @@ class Density:
1583
1572
  ret = ret.rigid_transform(
1584
1573
  translation=shift,
1585
1574
  rotation_matrix=np.eye(ret.data.ndim),
1586
- use_geometric_center=False,
1587
- order=1,
1575
+ use_geometric_center=True,
1576
+ order=order,
1588
1577
  )
1589
-
1590
- shift = np.subtract(center, self.center_of_mass(ret.data, cutoff))
1591
- return ret, shift
1578
+ ret.origin = np.subtract(ret.origin, np.multiply(shift, ret.sampling_rate))
1579
+ return ret
1592
1580
 
1593
1581
  def rigid_transform(
1594
1582
  self,
@@ -1895,6 +1883,8 @@ class Density:
1895
1883
 
1896
1884
  lower_bound, upper_bound = density_boundaries
1897
1885
  if method == "ConvexHull":
1886
+ from scipy.spatial import ConvexHull
1887
+
1898
1888
  binary = np.transpose(np.where(self.data > lower_bound))
1899
1889
  hull = ConvexHull(binary)
1900
1890
  surface_points = binary[hull.vertices[:]]
@@ -2032,27 +2022,24 @@ class Density:
2032
2022
  eroded_mask = binary_erosion(eroded_mask)
2033
2023
  return core_indices
2034
2024
 
2035
- @staticmethod
2036
- def center_of_mass(arr: NDArray, cutoff: float = None) -> NDArray:
2025
+ def center_of_mass(self, arr: NDArray = None, cutoff: float = None) -> NDArray:
2037
2026
  """
2038
- Computes the center of mass of a numpy ndarray instance using all available
2039
- elements. For template matching it typically makes sense to only input
2040
- positive densities.
2027
+ Computes the center of mass of a numpy ndarray instance.
2041
2028
 
2042
2029
  Parameters
2043
2030
  ----------
2044
- arr : NDArray
2045
- Array to compute the center of mass of.
2031
+ arr : NDArray, optional
2032
+ Array to compute the center of mass of, default is :py:attr:`Density.data`.
2046
2033
  cutoff : float, optional
2047
- Densities less than or equal to cutoff are nullified for center
2048
- of mass computation. By default considers all values.
2034
+ Density cutoff for calculation. Defaults to None, which considers all values.
2049
2035
 
2050
2036
  Returns
2051
2037
  -------
2052
2038
  NDArray
2053
2039
  Center of mass with shape (arr.ndim).
2054
2040
  """
2055
- return NumpyFFTWBackend().center_of_mass(arr**2, cutoff)
2041
+ arr = self.data if arr is None else arr
2042
+ return NumpyFFTWBackend().center_of_mass(arr, cutoff)
2056
2043
 
2057
2044
  @classmethod
2058
2045
  def match_densities(
@@ -2101,13 +2088,13 @@ class Density:
2101
2088
  -----
2102
2089
  No densities below cutoff_template are present in the returned Density object.
2103
2090
  """
2104
- from .matching_utils import normalize_template
2091
+ from .matching_utils import standardize
2105
2092
  from .matching_optimization import optimize_match, create_score_object
2106
2093
 
2107
2094
  template_mask = template.empty
2108
2095
  template_mask.data.fill(1)
2109
2096
 
2110
- normalize_template(
2097
+ template.data = standardize(
2111
2098
  template=template.data,
2112
2099
  mask=template_mask.data,
2113
2100
  n_observations=template_mask.data.sum(),
@@ -2141,8 +2128,8 @@ class Density:
2141
2128
  template_coordinates = template_coordinates * template_scaling[:, None]
2142
2129
 
2143
2130
  mass_center_difference = np.subtract(
2144
- cls.center_of_mass(target.data, cutoff_target),
2145
- cls.center_of_mass(template.data, cutoff_template),
2131
+ target.center_of_mass(target.data, cutoff_target),
2132
+ target.center_of_mass(template.data, cutoff_template),
2146
2133
  ).astype(int)
2147
2134
  template_coordinates += mass_center_difference[:, None]
2148
2135
 
@@ -2196,7 +2183,7 @@ class Density:
2196
2183
 
2197
2184
  Parameters
2198
2185
  ----------
2199
- target : Density
2186
+ target : :py:class:`Density`
2200
2187
  The target map for template matching.
2201
2188
  template : Structure
2202
2189
  The template that should be aligned to the target.
@@ -2259,3 +2246,60 @@ class Density:
2259
2246
  coordinates = np.array(np.where(data > 0))
2260
2247
  weights = self.data[tuple(coordinates)]
2261
2248
  return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
2249
+
2250
+ @staticmethod
2251
+ def fourier_shell_correlation(density1: "Density", density2: "Density") -> NDArray:
2252
+ """
2253
+ Computes the Fourier Shell Correlation (FSC) between two instances of `Density`.
2254
+
2255
+ The Fourier transforms of the input maps are divided into shells
2256
+ based on their spatial frequency. The correlation between corresponding shells
2257
+ in the two maps is computed to give the FSC.
2258
+
2259
+ Parameters
2260
+ ----------
2261
+ density1 : :py:class:`Density`
2262
+ Reference for comparison.
2263
+ density2 : :py:class:`Density`
2264
+ Target for comparison.
2265
+
2266
+ Returns
2267
+ -------
2268
+ NDArray
2269
+ An array of shape (N, 2), where N is the number of shells.
2270
+ The first column represents the spatial frequency for each shell
2271
+ and the second column represents the corresponding FSC.
2272
+
2273
+ References
2274
+ ----------
2275
+ .. [1] https://github.com/tdgrant1/denss/blob/master/saxstats/saxstats.py
2276
+ """
2277
+ side = density1.data.shape[0]
2278
+ df = 1.0 / side
2279
+
2280
+ qx_ = np.fft.fftfreq(side) * side * df
2281
+ qx, qy, qz = np.meshgrid(qx_, qx_, qx_, indexing="ij")
2282
+ qr = np.sqrt(qx**2 + qy**2 + qz**2)
2283
+
2284
+ qmax = np.max(qr)
2285
+ qstep = np.min(qr[qr > 0])
2286
+ nbins = int(qmax / qstep)
2287
+ qbins = np.linspace(0, nbins * qstep, nbins + 1)
2288
+ qbin_labels = np.searchsorted(qbins, qr, "right") - 1
2289
+
2290
+ F1 = np.fft.fftn(density1.data)
2291
+ F2 = np.fft.fftn(density2.data)
2292
+
2293
+ qbin_labels = qbin_labels.reshape(-1)
2294
+ numerator = np.bincount(
2295
+ qbin_labels, weights=np.real(F1 * np.conj(F2)).reshape(-1)
2296
+ )
2297
+ term1 = np.bincount(qbin_labels, weights=np.abs(F1).reshape(-1) ** 2)
2298
+ term2 = np.bincount(qbin_labels, weights=np.abs(F2).reshape(-1) ** 2)
2299
+ np.multiply(term1, term2, out=term1)
2300
+ denominator = np.sqrt(term1)
2301
+ FSC = np.divide(numerator, denominator)
2302
+
2303
+ qidx = np.where(qbins < qx.max())
2304
+
2305
+ return np.vstack((qbins[qidx], FSC[qidx])).T
Binary file
tme/filters/_utils.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """
2
- Utilities for the generation of frequency grids.
2
+ Utilities for the creation of composable filters.
3
3
 
4
4
  Copyright (c) 2024 European Molecular Biology Laboratory
5
5
 
@@ -79,6 +79,7 @@ def frequency_grid_at_angle(
79
79
  sampling_rate: Tuple[float],
80
80
  opening_axis: int = None,
81
81
  tilt_axis: int = None,
82
+ fftshift: bool = False,
82
83
  ) -> NDArray:
83
84
  """
84
85
  Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.
@@ -94,13 +95,16 @@ def frequency_grid_at_angle(
94
95
  shape : tuple of int
95
96
  The shape of the grid.
96
97
  angle : float
97
- The angle at which to generate the grid.
98
+ The angle at which to generate the grid in degrees.
98
99
  sampling_rate : tuple of float
99
100
  The sampling rate for each dimension.
100
101
  opening_axis : int, optional
101
102
  The projection axis, defaults to None.
102
103
  tilt_axis : int, optional
103
104
  The axis along which the grid is tilted, defaults to None.
105
+ fftshift : bool, optional
106
+ Whether to return a grid centered at shape // 2. Default is grid centered around
107
+ origin, which is compliant with the rfftn definitions used in this project.
104
108
 
105
109
  Returns
106
110
  -------
@@ -114,14 +118,17 @@ def frequency_grid_at_angle(
114
118
  shape=shape, opening_axis=opening_axis, reduce_dim=False
115
119
  )
116
120
 
117
- if angle == 0:
121
+ missing_axes = opening_axis is None or tilt_axis is None
122
+ if angle == 0 or missing_axes or len(set(shape)) == 1:
123
+ # Crop the sampling rate to tilt shape
118
124
  sampling_rate = compute_tilt_shape(
119
125
  shape=sampling_rate, opening_axis=opening_axis, reduce_dim=True
120
126
  )
121
- index_grid = fftfreqn(
127
+ return fftfreqn(
122
128
  tuple(x for x in tilt_shape if x != 1),
123
129
  sampling_rate=sampling_rate,
124
130
  compute_euclidean_norm=True,
131
+ fftshift=fftshift,
125
132
  )
126
133
 
127
134
  if angle != 0:
@@ -130,9 +137,11 @@ def frequency_grid_at_angle(
130
137
 
131
138
  angles = np.zeros(len(shape))
132
139
  angles[tilt_axis] = angle
133
- rotation_matrix = euler_to_rotationmatrix(np.roll(angles, opening_axis - 1))
140
+ rotation_matrix = euler_to_rotationmatrix(
141
+ np.roll(angles, opening_axis - 1), seq="zyz"
142
+ )
134
143
 
135
- index_grid = fftfreqn(tilt_shape, sampling_rate=None)
144
+ index_grid = fftfreqn(tilt_shape, sampling_rate=None, fftshift=fftshift)
136
145
  index_grid = np.einsum("ij,j...->i...", rotation_matrix, index_grid)
137
146
  norm = np.multiply(sampling_rate, shape).astype(int)
138
147
 
@@ -146,19 +155,25 @@ def frequency_grid_at_angle(
146
155
  def fftfreqn(
147
156
  shape: Tuple[int],
148
157
  sampling_rate: Tuple[float],
158
+ fftshift: bool = False,
149
159
  compute_euclidean_norm: bool = False,
150
160
  shape_is_real_fourier: bool = False,
151
161
  return_sparse_grid: bool = False,
152
162
  ) -> NDArray:
153
163
  """
154
- Generate the n-dimensional discrete Fourier transform sample frequencies.
164
+ Generate n-dimensional (frequency) grids.
155
165
 
156
166
  Parameters
157
167
  ----------
158
168
  shape : Tuple[int]
159
169
  The shape of the data.
160
170
  sampling_rate : float or Tuple[float]
161
- The sampling rate.
171
+ Sets the maximum value along each axis in shape to x=1/(2*sampling_rate), e.g.,
172
+ a sampling_rate of 1 yields a grid from -x to x - 1/n for an axis of even length n. A sampling
173
+ rate of None returns a grid from -n/2 to n/2 - 1
174
+ fftshift : bool, optional
175
+ Whether to return a grid centered at shape // 2. Default is grid centered around
176
+ origin, which is compliant with the rfftn definitions used in this project.
162
177
  compute_euclidean_norm : bool, optional
163
178
  Whether to compute the Euclidean norm, defaults to False.
164
179
  shape_is_real_fourier : bool, optional
@@ -181,10 +196,18 @@ def fftfreqn(
181
196
  if sampling_rate is not None:
182
197
  norm[-1] = (shape[-1] - 1) * 2 * sampling_rate
183
198
 
184
- grids = []
199
+ ndim, grids = len(shape), []
185
200
  for i, x in enumerate(shape):
186
201
  baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
187
202
  grid = (np_be.arange(x, dtype=np_be._int_dtype) - center[i]) / norm[i]
203
+
204
+ # The grid is built centered around shape // 2, so undo the centering with ifftshift
205
+ if not fftshift:
206
+ if shape_is_real_fourier and i == (ndim - 1):
207
+ pass
208
+ else:
209
+ grid = np.fft.ifftshift(grid)
210
+
188
211
  grid = np_be.astype(grid, np_be._float_dtype)
189
212
  grids.append(np_be.reshape(grid, baseline_dims))
190
213
 
@@ -197,9 +220,7 @@ def fftfreqn(
197
220
  return grids
198
221
 
199
222
  grid_flesh = np_be.full(shape, fill_value=1, dtype=np_be._float_dtype)
200
- grids = np_be.stack(tuple(grid * grid_flesh for grid in grids))
201
-
202
- return grids
223
+ return np_be.stack(tuple(grid * grid_flesh for grid in grids))
203
224
 
204
225
 
205
226
  def crop_real_fourier(data: BackendArray) -> BackendArray:
@@ -231,27 +252,28 @@ def compute_fourier_shape(
231
252
 
232
253
 
233
254
  def shift_fourier(
234
- data: BackendArray, shape_is_real_fourier: bool = False
255
+ data: BackendArray, shape_is_real_fourier: bool = False, ifftshift: bool = True
235
256
  ) -> BackendArray:
236
257
  comp = be
237
258
  if isinstance(data, np.ndarray):
238
259
  comp = NumpyFFTWBackend()
260
+
239
261
  shape = comp.to_backend_array(data.shape)
240
- shift = comp.add(comp.divide(shape, 2), comp.mod(shape, 2))
262
+ shift = comp.divide(shape, 2)
263
+ if ifftshift:
264
+ shift = comp.add(shift, comp.mod(shape, 2))
265
+
241
266
  shift = [int(x) for x in shift]
242
267
  if shape_is_real_fourier:
243
268
  shift[-1] = 0
244
-
245
- data = comp.roll(data, shift, tuple(i for i in range(len(shift))))
246
- return data
269
+ return comp.roll(data, shift, tuple(i for i in range(len(shift))))
247
270
 
248
271
 
249
272
  def create_reconstruction_filter(
250
- filter_shape: Tuple[int], filter_type: str, **kwargs: Dict
273
+ filter_shape: Tuple[int], filter_type: str, fftshift: bool = True, **kwargs: Dict
251
274
  ):
252
275
  """
253
- Create a reconstruction filter of given filter_type. The DC component of
254
- the filter will be located in the array center.
276
+ Create a reconstruction filter of given filter_type.
255
277
 
256
278
  Parameters
257
279
  ----------
@@ -274,6 +296,8 @@ def create_reconstruction_filter(
274
296
  +---------------+----------------------------------------------------+
275
297
  | hamming | |w| * (.54 + .46 ( cos(|w| * pi))) [2]_ |
276
298
  +---------------+----------------------------------------------------+
299
+ fftshift : bool, optional
300
+ Whether the DC component is located at the array center, default is True.
277
301
  kwargs: Dict
278
302
  Keyword arguments for particular filter_types.
279
303
 
@@ -288,7 +312,9 @@ def create_reconstruction_filter(
288
312
  .. [2] https://odlgroup.github.io/odl/index.html
289
313
  """
290
314
  filter_type = str(filter_type).lower()
291
- freq = fftfreqn(filter_shape, sampling_rate=0.5, compute_euclidean_norm=True)
315
+ freq = fftfreqn(
316
+ filter_shape, sampling_rate=0.5, compute_euclidean_norm=True, fftshift=fftshift
317
+ )
292
318
 
293
319
  if filter_type == "ram-lak":
294
320
  ret = np.copy(freq)
@@ -297,8 +323,8 @@ def create_reconstruction_filter(
297
323
  for dim, size in enumerate(filter_shape):
298
324
  n = np.concatenate(
299
325
  (
300
- np.arange(1, size / 2 + 1, 2, dtype=int),
301
- np.arange(size / 2 - 1, 0, -2, dtype=int),
326
+ np.arange(1, size // 2 + 1, 2, dtype=int),
327
+ np.arange(size // 2 - 1, 0, -2, dtype=int),
302
328
  )
303
329
  )
304
330
  ret1d = np.zeros(size)
@@ -316,7 +342,9 @@ def create_reconstruction_filter(
316
342
  if tilt_angles is False:
317
343
  raise ValueError("'ramp' filter requires specifying tilt angles.")
318
344
  size = filter_shape[0]
319
- ret = fftfreqn((size,), sampling_rate=1, compute_euclidean_norm=True)
345
+ ret = fftfreqn(
346
+ (size,), sampling_rate=1, compute_euclidean_norm=True, fftshift=fftshift
347
+ )
320
348
  min_increment = np.radians(np.min(np.abs(np.diff(np.sort(tilt_angles)))))
321
349
  ret *= min_increment * size
322
350
  ret = np.fmin(ret, 1, out=ret)