pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/postprocess.py +35 -21
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.post1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/METADATA +5 -7
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/RECORD +55 -48
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +35 -21
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_analyzer.py +2 -3
  20. tests/test_backends.py +3 -9
  21. tests/test_density.py +0 -1
  22. tests/test_extensions.py +0 -1
  23. tests/test_matching_utils.py +10 -60
  24. tests/test_rotations.py +1 -1
  25. tme/__version__.py +1 -1
  26. tme/analyzer/_utils.py +4 -4
  27. tme/analyzer/aggregation.py +35 -15
  28. tme/analyzer/peaks.py +11 -10
  29. tme/backends/_jax_utils.py +26 -13
  30. tme/backends/_numpyfftw_utils.py +270 -0
  31. tme/backends/cupy_backend.py +16 -55
  32. tme/backends/jax_backend.py +76 -37
  33. tme/backends/matching_backend.py +17 -51
  34. tme/backends/mlx_backend.py +1 -27
  35. tme/backends/npfftw_backend.py +71 -65
  36. tme/backends/pytorch_backend.py +1 -26
  37. tme/density.py +2 -6
  38. tme/extensions.cpython-311-darwin.so +0 -0
  39. tme/filters/ctf.py +22 -21
  40. tme/filters/wedge.py +10 -7
  41. tme/mask.py +341 -0
  42. tme/matching_data.py +31 -19
  43. tme/matching_exhaustive.py +37 -47
  44. tme/matching_optimization.py +2 -1
  45. tme/matching_scores.py +229 -411
  46. tme/matching_utils.py +73 -422
  47. tme/memory.py +1 -1
  48. tme/orientations.py +13 -8
  49. tme/rotations.py +1 -1
  50. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  51. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/estimate_memory_usage.py +0 -0
  52. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocess.py +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/WHEEL +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/entry_points.txt +0 -0
  55. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/licenses/LICENSE +0 -0
  56. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/top_level.txt +0 -0
tme/backends/mlx_backend.py CHANGED
@@ -6,7 +6,7 @@ Copyright (c) 2024 European Molecular Biology Laboratory
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- from typing import Tuple, List, Callable
9
+ from typing import Tuple, List
10
10
 
11
11
  import numpy as np
12
12
 
@@ -144,32 +144,6 @@ class MLXBackend(NumpyFFTWBackend):
144
144
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
145
145
  return arr[box]
146
146
 
147
- def build_fft(
148
- self,
149
- fwd_shape: Tuple[int],
150
- inv_shape: Tuple[int] = None,
151
- inv_output_shape: Tuple[int] = None,
152
- fwd_axes: Tuple[int] = None,
153
- inv_axes: Tuple[int] = None,
154
- **kwargs,
155
- ) -> Tuple[Callable, Callable]:
156
- # Runs on mlx.core.cpu until Metal support is available
157
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
158
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
159
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
160
-
161
- def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
162
- out[:] = self._array_backend.fft.rfftn(
163
- arr, s=s, axes=axes, stream=self._array_backend.cpu
164
- )
165
-
166
- def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
167
- out[:] = self._array_backend.fft.irfftn(
168
- arr, s=s, axes=axes, stream=self._array_backend.cpu
169
- )
170
-
171
- return rfftn, irfftn
172
-
173
147
  def rfftn(self, arr, *args, **kwargs):
174
148
  return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
175
149
 
tme/backends/npfftw_backend.py CHANGED
@@ -9,17 +9,23 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
9
9
  import os
10
10
  from psutil import virtual_memory
11
11
  from contextlib import contextmanager
12
- from typing import Tuple, Dict, List, Type
12
+ from typing import Tuple, List, Type
13
13
 
14
14
  import scipy
15
15
  import numpy as np
16
16
  from scipy.ndimage import maximum_filter, affine_transform
17
- from pyfftw.builders import rfftn as rfftn_builder, irfftn as irfftn_builder
18
- from pyfftw import zeros_aligned, simd_alignment, FFTW, next_fast_len, interfaces
17
+ from pyfftw import (
18
+ zeros_aligned,
19
+ simd_alignment,
20
+ next_fast_len,
21
+ interfaces,
22
+ config,
23
+ )
19
24
 
20
25
  from ..types import NDArray, BackendArray, shm_type
21
26
  from .matching_backend import MatchingBackend, _create_metafunction
22
27
 
28
+
23
29
  os.environ["MKL_NUM_THREADS"] = "1"
24
30
  os.environ["OMP_NUM_THREADS"] = "1"
25
31
  os.environ["PYFFTW_NUM_THREADS"] = "1"
@@ -103,6 +109,20 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
103
109
  self.solve_triangular = self._solve_triangular
104
110
  self.linalg.solve_triangular = scipy.linalg.solve_triangular
105
111
 
112
+ try:
113
+ from ._numpyfftw_utils import rfftn as rfftn_cache
114
+ from ._numpyfftw_utils import irfftn as irfftn_cache
115
+
116
+ self._rfftn = rfftn_cache
117
+ self._irfftn = irfftn_cache
118
+ except Exception as e:
119
+ print(e)
120
+
121
+ config.NUM_THREADS = 1
122
+ config.PLANNER_EFFORT = "FFTW_MEASURE"
123
+ interfaces.cache.enable()
124
+ interfaces.cache.set_keepalive_time(360)
125
+
106
126
  def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
107
127
  # Upper argument is not supported until numpy 2.0
108
128
  ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
@@ -138,7 +158,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
138
158
  return float
139
159
 
140
160
  def free_cache(self):
141
- pass
161
+ interfaces.cache.disable()
142
162
 
143
163
  def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
144
164
  return self._array_backend.transpose(arr, *args, **kwargs)
@@ -181,6 +201,9 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
181
201
  sorted_indices = self.unravel_index(indices=sorted_indices, shape=arr.shape)
182
202
  return sorted_indices
183
203
 
204
+ def ssum(self, arr, *args, **kwargs):
205
+ return self.sum(self.square(arr), *args, **kwargs)
206
+
184
207
  def indices(self, *args, **kwargs) -> NDArray:
185
208
  return self._array_backend.indices(*args, **kwargs)
186
209
 
@@ -240,70 +263,53 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
240
263
  b[tuple(bind)] = arr[tuple(aind)]
241
264
  return b
242
265
 
243
- def build_fft(
244
- self,
245
- fwd_shape: Tuple[int],
246
- inv_shape: Tuple[int],
247
- real_dtype: type,
248
- cmpl_dtype: type,
249
- fftargs: Dict = {},
250
- inv_output_shape: Tuple[int] = None,
251
- temp_fwd: NDArray = None,
252
- temp_inv: NDArray = None,
253
- fwd_axes: Tuple[int] = None,
254
- inv_axes: Tuple[int] = None,
255
- ) -> Tuple[FFTW, FFTW]:
256
- if temp_fwd is None:
257
- temp_fwd = (
258
- self.zeros(fwd_shape, real_dtype) if temp_fwd is None else temp_fwd
259
- )
260
- if temp_inv is None:
261
- temp_inv = (
262
- self.zeros(inv_shape, cmpl_dtype) if temp_inv is None else temp_inv
263
- )
264
-
265
- default_values = {
266
- "planner_effort": "FFTW_MEASURE",
267
- "auto_align_input": False,
268
- "auto_contiguous": False,
269
- "avoid_copy": True,
270
- "overwrite_input": True,
271
- "threads": 1,
272
- }
273
- for key in default_values:
274
- if key in fftargs:
275
- continue
276
- fftargs[key] = default_values[key]
277
-
278
- rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
279
- _rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)
280
- overwrite_input = fftargs.pop("overwrite_input", None)
281
-
282
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
283
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
284
- _irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)
285
-
286
- def _rfftn_wrapper(arr, out, *args, **kwargs):
287
- return _rfftn(arr, out)
288
-
289
- def _irfftn_wrapper(arr, out, *args, **kwargs):
290
- return _irfftn(arr, out)
291
-
292
- fftargs["overwrite_input"] = overwrite_input
293
- return _rfftn_wrapper, _irfftn_wrapper
266
+ def _rfftn(self, arr, out=None, **kwargs):
267
+ ret = interfaces.numpy_fft.rfftn(arr, **kwargs)
268
+ if out is not None:
269
+ out[:] = ret
270
+ return out
271
+ return ret
294
272
 
295
- @staticmethod
296
- def _format_fft_shape(shape: Tuple[int], axes: Tuple[int] = None):
297
- if axes is None:
298
- return shape
299
- axes = tuple(sorted(range(len(shape))[i] for i in axes))
300
- return tuple(shape[i] for i in axes)
273
+ def _irfftn(self, arr, out=None, **kwargs):
274
+ ret = interfaces.numpy_fft.irfftn(arr, **kwargs)
275
+ if out is not None:
276
+ out[:] = ret
277
+ return out
278
+ return ret
301
279
 
302
- def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
303
- return interfaces.numpy_fft.rfftn(arr, **kwargs)
280
+ def rfftn(
281
+ self,
282
+ arr: NDArray,
283
+ out=None,
284
+ auto_align_input: bool = False,
285
+ auto_contiguous: bool = False,
286
+ overwrite_input: bool = True,
287
+ **kwargs,
288
+ ) -> NDArray:
289
+ return self._rfftn(
290
+ arr,
291
+ auto_align_input=auto_align_input,
292
+ auto_contiguous=auto_contiguous,
293
+ overwrite_input=overwrite_input,
294
+ **kwargs,
295
+ )
304
296
 
305
- def irfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
306
- return interfaces.numpy_fft.irfftn(arr, **kwargs)
297
+ def irfftn(
298
+ self,
299
+ arr: NDArray,
300
+ out=None,
301
+ auto_align_input: bool = False,
302
+ auto_contiguous: bool = False,
303
+ overwrite_input: bool = True,
304
+ **kwargs,
305
+ ) -> NDArray:
306
+ return self._irfftn(
307
+ arr,
308
+ auto_align_input=auto_align_input,
309
+ auto_contiguous=auto_contiguous,
310
+ overwrite_input=overwrite_input,
311
+ **kwargs,
312
+ )
307
313
 
308
314
  def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
309
315
  new_shape = self.to_backend_array(newshape)
tme/backends/pytorch_backend.py CHANGED
@@ -7,7 +7,7 @@ Copyright (c) 2023 European Molecular Biology Laboratory
7
7
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
8
  """
9
9
 
10
- from typing import Tuple, Callable
10
+ from typing import Tuple
11
11
  from contextlib import contextmanager
12
12
  from multiprocessing import shared_memory
13
13
  from multiprocessing.managers import SharedMemoryManager
@@ -273,31 +273,6 @@ class PytorchBackend(NumpyFFTWBackend):
273
273
  kwargs["device"] = self.device
274
274
  return self._array_backend.eye(*args, **kwargs)
275
275
 
276
- def build_fft(
277
- self,
278
- fwd_shape: Tuple[int],
279
- inv_shape: Tuple[int],
280
- inv_output_shape: Tuple[int] = None,
281
- fwd_axes: Tuple[int] = None,
282
- inv_axes: Tuple[int] = None,
283
- **kwargs,
284
- ) -> Tuple[Callable, Callable]:
285
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
286
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
287
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
288
-
289
- def rfftn(
290
- arr: TorchTensor, out: TorchTensor, s=rfft_shape, axes=fwd_axes
291
- ) -> TorchTensor:
292
- return self._array_backend.fft.rfftn(arr, s=s, out=out, dim=axes)
293
-
294
- def irfftn(
295
- arr: TorchTensor, out: TorchTensor = None, s=irfft_shape, axes=inv_axes
296
- ) -> TorchTensor:
297
- return self._array_backend.fft.irfftn(arr, s=s, out=out, dim=axes)
298
-
299
- return rfftn, irfftn
300
-
301
276
  def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
302
277
  kwargs["dim"] = kwargs.pop("axes", None)
303
278
  return self._array_backend.fft.rfftn(arr, **kwargs)
tme/density.py CHANGED
@@ -36,6 +36,7 @@ from .matching_utils import (
36
36
  array_to_memmap,
37
37
  memmap_to_array,
38
38
  minimum_enclosing_box,
39
+ is_gzipped,
39
40
  )
40
41
 
41
42
  __all__ = ["Density"]
@@ -331,6 +332,7 @@ class Density:
331
332
  if non_standard_crs:
332
333
  data = np.transpose(data, crs_index)
333
334
  origin = np.take(origin, crs_index)
335
+ sampling_rate = np.take(sampling_rate, crs_index)
334
336
 
335
337
  return data.T, origin[::-1], sampling_rate[::-1], metadata
336
338
 
@@ -2257,9 +2259,3 @@ class Density:
2257
2259
  coordinates = np.array(np.where(data > 0))
2258
2260
  weights = self.data[tuple(coordinates)]
2259
2261
  return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
2260
-
2261
-
2262
- def is_gzipped(filename: str) -> bool:
2263
- """Check if a file is a gzip file by reading its magic number."""
2264
- with open(filename, "rb") as f:
2265
- return f.read(2) == b"\x1f\x8b"
tme/extensions.cpython-311-darwin.so CHANGED
Binary file
tme/filters/ctf.py CHANGED
@@ -36,7 +36,7 @@ class CTF(ComposableFilter):
36
36
 
37
37
  #: The shape of the to-be created mask.
38
38
  shape: Tuple[int] = None
39
- #: The defocus value in x direction (in units of sampling rate).
39
+ #: The defocus in x direction (in units of sampling rate).
40
40
  defocus_x: Tuple[float] = None
41
41
  #: The tilt angles.
42
42
  angles: Tuple[float] = None
@@ -164,7 +164,7 @@ class CTF(ComposableFilter):
164
164
  shape : tuple of int
165
165
  The shape of the CTF.
166
166
  defocus_x : tuple of float
167
- The defocus value in x direction.
167
+ The defocus in x direction (in units of sampling rate).
168
168
  angles : tuple of float
169
169
  The tilt angles.
170
170
  opening_axis : int, optional
@@ -178,7 +178,7 @@ class CTF(ComposableFilter):
178
178
  defocus_angle : tuple of float, optional
179
179
  The defocus angle in radians, defaults to 0.
180
180
  defocus_y : tuple of float, optional
181
- The defocus value in y direction, defaults to None.
181
+ The defocus in y direction (in units of sampling rate).
182
182
  correct_defocus_gradient : bool, optional
183
183
  Whether to correct defocus gradient, defaults to False.
184
184
  sampling_rate : tuple of float, optional
@@ -219,14 +219,12 @@ class CTF(ComposableFilter):
219
219
  corrected_tilt_axis -= 1
220
220
 
221
221
  for index, angle in enumerate(angles):
222
- defocus_x, defocus_y = defoci_x[index], defoci_y[index]
223
-
224
222
  correction = correct_defocus_gradient and angle is not None
225
223
  chi = create_ctf(
226
224
  angle=angle,
227
225
  shape=ctf_shape,
228
- defocus_x=defocus_x,
229
- defocus_y=defocus_y,
226
+ defocus_x=defoci_x[index],
227
+ defocus_y=defoci_y[index],
230
228
  sampling_rate=sampling_rate,
231
229
  acceleration_voltage=acceleration_voltage[index],
232
230
  correct_defocus_gradient=correction,
@@ -243,12 +241,10 @@ class CTF(ComposableFilter):
243
241
  stack[index] = chi
244
242
 
245
243
  # Avoid contrast inversion
246
- np.negative(stack, out=stack)
244
+ stack = np.negative(stack, out=stack)
247
245
  if flip_phase:
248
- np.abs(stack, out=stack)
249
-
250
- stack = be.to_backend_array(np.squeeze(stack))
251
- return stack
246
+ stack = np.abs(stack, out=stack)
247
+ return be.to_backend_array(np.squeeze(stack))
252
248
 
253
249
 
254
250
  class CTFReconstructed(CTF):
@@ -281,7 +277,7 @@ class CTFReconstructed(CTF):
281
277
  shape : tuple of int
282
278
  The shape of the CTF.
283
279
  defocus_x : tuple of float
284
- The defocus value in x direction.
280
+ The defocus in x direction in units of sampling rate.
285
281
  opening_axis : int, optional
286
282
  The axis around which the wedge is opened, defaults to 2.
287
283
  amplitude_contrast : float, optional
@@ -291,7 +287,7 @@ class CTFReconstructed(CTF):
291
287
  defocus_angle : tuple of float, optional
292
288
  The defocus angle in radians, defaults to 0.
293
289
  defocus_y : tuple of float, optional
294
- The defocus value in y direction, defaults to None.
290
+ The defocus in y direction in units of sampling rate.
295
291
  sampling_rate : tuple of float, optional
296
292
  The sampling rate, defaults to 1.
297
293
  acceleration_voltage : float, optional
@@ -321,18 +317,15 @@ class CTFReconstructed(CTF):
321
317
  defocus_angle=defocus_angle,
322
318
  amplitude_contrast=amplitude_contrast,
323
319
  )
324
- stack = shift_fourier(data=stack, shape_is_real_fourier=False)
325
-
326
320
  # Avoid contrast inversion
327
321
  np.negative(stack, out=stack)
328
322
  if flip_phase:
329
323
  np.abs(stack, out=stack)
330
324
 
331
- stack = be.to_backend_array(np.squeeze(stack))
325
+ stack = shift_fourier(data=stack, shape_is_real_fourier=False)
332
326
  if return_real_fourier:
333
327
  stack = crop_real_fourier(stack)
334
-
335
- return stack
328
+ return be.to_backend_array(np.squeeze(stack))
336
329
 
337
330
 
338
331
  def _from_xml(filename: str) -> Dict:
@@ -436,6 +429,9 @@ def _from_ctffind(filename: str) -> Dict:
436
429
  output[key] = np.array(output[key])
437
430
 
438
431
  output["additional_phase_shift"] = np.degrees(output["additional_phase_shift"])
432
+ cs = output.get("spherical_aberration")
433
+ if cs is not None:
434
+ output["spherical_aberration"] = float(cs) * 1e7
439
435
  return output
440
436
 
441
437
 
@@ -566,7 +562,7 @@ def create_ctf(
566
562
  amplitude_contrast : float, optional
567
563
  Amplitude contrast of microscope, defaults to 0.07.
568
564
  spherical_aberration : float, optional
569
- Spherical aberration of microscope in Angstrom.
565
+ Spherical aberration of microscope in units of sampling rate.
570
566
  angle : float, optional
571
567
  Assume the created CTF is a projection over opening_axis observed at angle.
572
568
  opening_axis : int, optional
@@ -590,10 +586,14 @@ def create_ctf(
590
586
  electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
591
587
  electron_wavelength /= sampling_rate
592
588
  aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2
589
+
590
+ defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
591
+ defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
593
592
  if correct_defocus_gradient or defocus_y is not None:
594
593
  if len(shape) < 2:
595
594
  raise ValueError(f"Length of shape needs to be at least 2, got {shape}")
596
595
 
596
+ # Axial distance from grid center in multiples of sampling rate
597
597
  sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
598
598
  grid = fftfreqn(
599
599
  shape=shape,
@@ -619,6 +619,7 @@ def create_ctf(
619
619
  defocus_sum = np.add(defocus_x, defocus_y)
620
620
  defocus_difference = np.subtract(defocus_x, defocus_y)
621
621
 
622
+ # Reusing grid, but in principle pure frequencies would suffice
622
623
  angular_grid = np.arctan2(grid[1], grid[0])
623
624
  defocus_difference = np.multiply(
624
625
  defocus_difference,
@@ -627,7 +628,7 @@ def create_ctf(
627
628
  defocus_x = np.add(defocus_sum, defocus_difference)
628
629
  defocus_x *= 0.5
629
630
 
630
- frequency_grid = fftfreqn(shape, sampling_rate=True, compute_euclidean_norm=True)
631
+ frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
631
632
  if angle is not None and opening_axis is not None and full_shape is not None:
632
633
  frequency_grid = frequency_grid_at_angle(
633
634
  shape=full_shape,
tme/filters/wedge.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
15
15
  from ..types import NDArray
16
16
  from ..backends import backend as be
17
17
  from .compose import ComposableFilter
18
- from ..matching_utils import centered
18
+ from ..matching_utils import center_slice
19
19
  from ..rotations import euler_to_rotationmatrix
20
20
  from ..parser import XMLParser, StarParser, MDOCParser
21
21
  from ._utils import (
@@ -207,11 +207,10 @@ class Wedge(ComposableFilter):
207
207
  )
208
208
  sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
209
209
  sigma = -2 * np.pi**2 * sigma**2
210
- np.square(frequency_grid, out=frequency_grid)
211
- np.multiply(sigma, frequency_grid, out=frequency_grid)
212
- np.exp(frequency_grid, out=frequency_grid)
213
- np.multiply(frequency_grid, np.cos(np.radians(angle)), out=frequency_grid)
214
- wedges[index] = frequency_grid
210
+ frequency_grid = np.square(frequency_grid, out=frequency_grid)
211
+ frequency_grid = np.multiply(sigma, frequency_grid, out=frequency_grid)
212
+ frequency_grid = np.exp(frequency_grid, out=frequency_grid)
213
+ wedges[index] = np.multiply(frequency_grid, np.cos(np.radians(angle)))
215
214
 
216
215
  return wedges
217
216
 
@@ -490,7 +489,11 @@ class WedgeReconstructed:
490
489
  )
491
490
  wedge_volume += plane_rotated * weights[index]
492
491
 
493
- wedge_volume = centered(wedge_volume, (shape[opening_axis], shape[tilt_axis]))
492
+ subset = center_slice(
493
+ wedge_volume.shape, (shape[opening_axis], shape[tilt_axis])
494
+ )
495
+ wedge_volume = wedge_volume[subset]
496
+
494
497
  np.fmin(wedge_volume, np.max(weights), wedge_volume)
495
498
 
496
499
  if opening_axis > tilt_axis: