pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post1.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/filters/whitening.py CHANGED
@@ -1,5 +1,5 @@
 """
-Implements class BandPassFilter to create Fourier filter representations.
+Implements class LinearWhiteningFilter
 
 Copyright (c) 2024 European Molecular Biology Laboratory
 
@@ -7,22 +7,26 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 from typing import Tuple, Dict
+from dataclasses import dataclass
 
 import numpy as np
 from scipy.ndimage import mean as ndimean
 from scipy.ndimage import map_coordinates
 
+from ._utils import fftfreqn
 from ..types import BackendArray
+from ..analyzer.peaks import batchify
 from ..backends import backend as be
 from .compose import ComposableFilter
-from ._utils import fftfreqn, compute_fourier_shape, shift_fourier
+
 
 
 __all__ = ["LinearWhiteningFilter"]
 
+@dataclass
 class LinearWhiteningFilter(ComposableFilter):
     """
-    Compute Fourier power spectrums and perform whitening.
+    Generate Fourier whitening filters.
 
     References
     ----------
@@ -34,12 +38,9 @@ class LinearWhiteningFilter(ComposableFilter):
     13375 (2023)
     """
 
-    def __init__(self, *args, **kwargs):
-        pass
-
     @staticmethod
     def _compute_spectrum(
-        data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
+        data_rfft: BackendArray, n_bins: int = None
     ) -> Tuple[BackendArray, BackendArray]:
         """
         Compute the power spectrum of the input data.
@@ -50,8 +51,6 @@ class LinearWhiteningFilter(ComposableFilter):
             The Fourier transform of the input data.
         n_bins : int, optional
             The number of bins for computing the spectrum, defaults to None.
-        batch_dimension : int, optional
-            Batch dimension to average over.
 
         Returns
         -------
@@ -60,7 +59,7 @@
         radial_averages : BackendArray
             Array containing the radial averages of the spectrum.
         """
-        shape = tuple(x for i, x in enumerate(data_rfft.shape) if i != batch_dimension)
+        shape = data_rfft.shape
 
         max_bins = max(max(shape[:-1]) // 2 + 1, shape[-1])
         n_bins = max_bins if n_bins is None else n_bins
@@ -71,25 +70,22 @@ class LinearWhiteningFilter(ComposableFilter):
             sampling_rate=0.5,
             shape_is_real_fourier=True,
             compute_euclidean_norm=True,
+            fftshift=False,
         )
         bins = be.to_numpy_array(bins)
-
-        # Implicit lowpass to nyquist
         bins = np.floor(bins * (n_bins - 1) + 0.5).astype(int)
-        fft_shift_axes = tuple(
-            i for i in range(data_rfft.ndim - 1) if i != batch_dimension
-        )
-        fourier_spectrum = np.fft.fftshift(data_rfft, axes=fft_shift_axes)
-        fourier_spectrum = np.abs(fourier_spectrum)
-        np.square(fourier_spectrum, out=fourier_spectrum)
+
+        fourier_spectrum = np.abs(data_rfft)
+        fourier_spectrum = np.square(fourier_spectrum, out=fourier_spectrum)
 
         radial_averages = ndimean(
             fourier_spectrum, labels=bins, index=np.arange(n_bins)
         )
-        np.sqrt(radial_averages, out=radial_averages)
-        np.reciprocal(radial_averages, out=radial_averages)
-        np.divide(radial_averages, radial_averages.max(), out=radial_averages)
-
+        radial_averages = np.sqrt(radial_averages, out=radial_averages)
+        radial_averages = np.where(radial_averages != 0, 1 / radial_averages, 0)
+        norm_factor = radial_averages.max()
+        if norm_factor != 0:
+            radial_averages = np.divide(radial_averages, norm_factor)
         return bins, radial_averages
 
     @staticmethod
@@ -104,21 +100,19 @@ class LinearWhiteningFilter(ComposableFilter):
             sampling_rate=0.5,
             shape_is_real_fourier=shape_is_real_fourier,
             compute_euclidean_norm=True,
+            fftshift=False,
         )
         grid = be.to_numpy_array(grid)
-        np.multiply(grid, (spectrum.shape[0] - 1), out=grid) + 0.5
+        grid = np.floor(np.multiply(grid, spectrum.shape[0] - 1) + 0.5)
         spectrum = map_coordinates(spectrum, grid.reshape(1, -1), order=order)
         return spectrum.reshape(grid.shape)
 
-    def __call__(
+    def _evaluate(
         self,
-        shape: Tuple[int],
-        data: BackendArray = None,
-        data_rfft: BackendArray = None,
-        n_bins: int = None,
-        batch_dimension: int = None,
+        shape: Tuple[int, ...],
+        data_rfft: BackendArray,
+        axes: Tuple[int] = (),
         order: int = 1,
-        return_real_fourier: bool = True,
         **kwargs: Dict,
     ) -> Dict:
         """
@@ -128,59 +122,29 @@ class LinearWhiteningFilter(ComposableFilter):
         ----------
         shape : tuple of ints
             Shape of the returned whitening filter.
-        data : BackendArray, optional
-            The input data, defaults to None.
         data_rfft : BackendArray, optional
             The Fourier transform of the input data, defaults to None.
-        n_bins : int, optional
-            The number of bins for computing the spectrum, defaults to None.
-        batch_dimension : int, optional
-            Batch dimension to average over.
-        return_real_fourier : tuple of int
-            Return a shape compliant with rfft, i.e., omit the negative frequencies
-            terms resulting in a return shape (*shape[:-1], shape[-1]//2+1)
+        axes : tuple of ints, optional
+            Axes to compute spectrum for independently.
         **kwargs : Dict
            Additional keyword arguments.
-
-        Returns
-        -------
-        dict
-            data: BackendArray
-                The filter mask.
-            shape: tuple of ints
-                The requested filter shape
-            return_real_fourier: bool
-                Whether data is compliant with rfftn.
-            is_multiplicative_filter: bool
-                Whether the filter is multiplicative in Fourier space.
         """
-        if data_rfft is None:
-            data_rfft = be.rfftn(data)
+        if isinstance(axes, int):
+            axes = (axes,)
 
+        stack = []
         data_rfft = be.to_numpy_array(data_rfft)
-        bins, radial_averages = self._compute_spectrum(
-            data_rfft, n_bins, batch_dimension
-        )
-        shape = tuple(int(x) for i, x in enumerate(shape) if i != batch_dimension)
-
-        shape_filter = shape
-        if return_real_fourier:
-            shape_filter = compute_fourier_shape(
+        for subset, _ in batchify(data_rfft.shape, axes):
+            _, radial_avg = self._compute_spectrum(np.squeeze(data_rfft[subset]))
+            ret = self._interpolate_spectrum(
+                spectrum=radial_avg,
                 shape=shape,
                 shape_is_real_fourier=False,
+                order=order,
             )
+            stack.append(ret)
 
-        ret = self._interpolate_spectrum(
-            spectrum=radial_averages,
-            shape=shape_filter,
-            shape_is_real_fourier=return_real_fourier,
-        )
-
-        ret = shift_fourier(data=ret, shape_is_real_fourier=return_real_fourier)
-
-        return {
-            "data": be.to_backend_array(ret),
-            "shape": shape,
-            "return_real_fourier": return_real_fourier,
-            "is_multiplicative_filter": True,
-        }
+        ret = np.array(stack)
+        if not len(axes):
+            ret = np.squeeze(ret)
+        return {"data": be.to_backend_array(ret), "shape": shape}
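The reworked LinearWhiteningFilter above boils down to a simple recipe: bin the power spectrum of the input into radial shells, take the reciprocal square root per shell, and normalize the result to a maximum of one. Below is a minimal NumPy sketch of that recipe for orientation; the helper name whitening_weights and the plain np.fft-based binning are illustrative stand-ins, while the package itself routes this through fftfreqn, the backend layer, and _interpolate_spectrum.

import numpy as np
from scipy.ndimage import mean as ndimean


def whitening_weights(image, n_bins=None):
    # Power spectrum in rfft layout.
    spec = np.abs(np.fft.rfftn(image)) ** 2
    # Per-axis frequency grids; the last axis uses the rfft frequencies.
    freqs = np.meshgrid(
        *[np.fft.fftfreq(n) for n in image.shape[:-1]],
        np.fft.rfftfreq(image.shape[-1]),
        indexing="ij",
    )
    radius = np.sqrt(sum(f ** 2 for f in freqs))
    n_bins = n_bins or max(max(image.shape[:-1]) // 2 + 1, spec.shape[-1])
    # Assign each Fourier coefficient to a radial shell.
    bins = np.floor(radius / radius.max() * (n_bins - 1) + 0.5).astype(int)
    radial = np.sqrt(ndimean(spec, labels=bins, index=np.arange(n_bins)))
    # Reciprocal amplitude per shell, guarded against empty or zero shells.
    weights = np.where(radial > 0, 1.0 / radial, 0.0)
    peak = weights.max()
    return weights / peak if peak != 0 else weights


# Usage: the weights flatten the radially averaged power spectrum of the input.
tomogram = np.random.rand(32, 32, 32).astype(np.float32)
shell_weights = whitening_weights(tomogram)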
tme/mask.py CHANGED
@@ -11,7 +11,7 @@ from typing import Tuple, Optional
 
 from .types import NDArray
 from scipy.ndimage import gaussian_filter
-from .matching_utils import rigid_transform
+from .matching_utils import _rigid_transform
 
 __all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]
 
@@ -76,7 +76,7 @@
     if orientation is not None:
         return_shape = indices.shape
         indices = indices.reshape(n, -1)
-        rigid_transform(
+        _rigid_transform(
             coordinates=indices,
             rotation_matrix=np.asarray(orientation),
             out=indices,
tme/matching_data.py CHANGED
@@ -15,7 +15,7 @@ from . import Density
 from .filters import Compose
 from .backends import backend as be
 from .types import BackendArray, NDArray
-from .matching_utils import compute_parallelization_schedule
+from .matching_utils import compute_parallelization_schedule, copy_docstring
 
 __all__ = ["MatchingData"]
 
@@ -249,8 +249,8 @@ class MatchingData:
         target_offset = np.zeros(len(self._output_target_shape), dtype=int)
         target_offset[mask] = [x.start for x in target_slice]
         mask = np.subtract(1, self._target_batch).astype(bool)
-        template_offset = np.zeros(len(self._output_template_shape), dtype=int)
-        template_offset[mask] = [x.start for x in template_slice]
+        # template_offset = np.zeros(len(self._output_template_shape), dtype=int)
+        # template_offset[mask] = [x.start for x in template_slice]
 
         translation_offset = tuple(x for x in target_offset)
 
@@ -485,14 +485,18 @@
 
     @staticmethod
     def _fourier_padding(
-        target_shape: Tuple[int],
-        template_shape: Tuple[int],
-        batch_mask: Tuple[int] = None,
+        target_shape: NDArray,
+        template_shape: NDArray,
+        target_batch: NDArray = None,
+        template_batch: NDArray = None,
         **kwargs,
     ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
-        if batch_mask is None:
-            batch_mask = np.zeros_like(template_shape)
-        batch_mask = np.asarray(batch_mask)
+        if target_batch is None:
+            target_batch = np.zeros_like(target_shape)
+        if template_batch is None:
+            template_batch = np.zeros_like(target_shape)
+
+        batch_mask = np.logical_or(target_batch, template_batch)
 
         fourier_pad = np.ones(len(template_shape), dtype=int)
         fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
@@ -500,7 +504,9 @@
 
         # Avoid padding batch dimensions
         pad_shape = np.maximum(target_shape, template_shape)
-        pad_shape = np.maximum(pad_shape, np.multiply(1 - batch_mask, pad_shape))
+        pad_shape = np.where(target_batch, target_shape, pad_shape)
+        pad_shape = np.where(template_batch, template_shape, pad_shape)
+
         ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
         conv_shape, fast_shape, fast_ft_shape = ret
 
@@ -538,10 +544,15 @@
         --------
         >>> conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=True)
         """
+        target_shape = kwargs.get("target_shape", self._output_target_shape)
+        template_shape = kwargs.get("template_shape", self._output_template_shape)
+        target_batch = kwargs.get("target_batch", self._target_batch)
+        template_batch = kwargs.get("template_batch", self._template_batch)
         return self._fourier_padding(
-            target_shape=be.to_numpy_array(self._output_target_shape),
-            template_shape=be.to_numpy_array(self._output_template_shape),
-            batch_mask=be.to_numpy_array(self._batch_mask),
+            target_shape=be.to_numpy_array(target_shape),
+            template_shape=be.to_numpy_array(template_shape),
+            target_batch=be.to_numpy_array(target_batch),
+            template_batch=be.to_numpy_array(template_batch),
         )
 
     def _score_mask(self, fast_shape: Tuple[int], shift: Tuple[int]) -> BackendArray:
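The reworked _fourier_padding above receives separate target_batch and template_batch masks instead of a single batch_mask, so each batch axis keeps the extent of the array it belongs to while only spatial axes are considered for Fourier padding. A minimal sketch of that input handling follows; the function name padding_inputs and the example shapes are assumptions, and the real method goes on to pass the returned values to be.compute_convolution_shapes.

import numpy as np


def padding_inputs(target_shape, template_shape, target_batch=None, template_batch=None):
    # Sketch of how batch axes are kept out of the Fourier padding decision.
    target_shape = np.asarray(target_shape)
    template_shape = np.asarray(template_shape)
    target_batch = np.zeros_like(target_shape) if target_batch is None else np.asarray(target_batch)
    template_batch = np.zeros_like(target_shape) if template_batch is None else np.asarray(template_batch)

    batch_mask = np.logical_or(target_batch, template_batch)

    # Pad only non-batch axes when computing convolution-friendly shapes.
    fourier_pad = np.ones(len(template_shape), dtype=int) * (1 - batch_mask)

    # Spatial axes take the larger extent; batch axes keep their own extent.
    pad_shape = np.maximum(target_shape, template_shape)
    pad_shape = np.where(target_batch, target_shape, pad_shape)
    pad_shape = np.where(template_batch, template_shape, pad_shape)
    return pad_shape, fourier_pad


# Example: a leading batch axis on the target (e.g. a stack of tomograms).
print(padding_inputs((5, 64, 64, 64), (1, 16, 16, 16), (1, 0, 0, 0), (0, 0, 0, 0)))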
@@ -568,6 +579,68 @@
         )
         return be.to_backend_array(score_mask)
 
+    def _transform_data(
+        self, method: str, data: BackendArray, batch_mask: Tuple[int], **kwargs
+    ) -> BackendArray:
+        """
+        Transform data using the specified method.
+
+        Parameters
+        ----------
+        method : str, optional
+            Transformation method, default "phase_randomization".
+            - "phase_randomization": Scrambles phase while preserving amplitude spectrum
+            - "standardize": Standardize to zero mean and unit variance
+            - "laplace": Applies Laplacian edge detection filter
+        **kwargs : dict
+            Method-specific arguments (e.g., mode="wrap" for laplace).
+
+        Returns
+        -------
+        BackendArray
+            Transformed data.
+        """
+        from scipy.ndimage import laplace
+        from .matching_utils import scramble_phases, standardize
+
+        def _standardize(arr: NDArray, **kwargs) -> NDArray:
+            return standardize(arr, 1, arr.size)
+
+        _supported_methods = {
+            "phase_randomization": scramble_phases,
+            "laplace": laplace,
+            "standardize": _standardize,
+        }
+        func = _supported_methods.get(method)
+        if func is None:
+            _supported = ",".join([str(x) for x in _supported_methods])
+            raise ValueError(f"Only methods {_supported} are supported.")
+
+        data = be.to_numpy_array(data)
+
+        ret = np.zeros_like(data)
+        for subset in self._batch_iter(data.shape, batch_mask):
+            ret[subset] = func(data[subset], **kwargs)
+        return be.to_backend_array(ret)
+
+    @copy_docstring(_transform_data)
+    def transform_target(self, method: str = "phase_randomization", **kwargs):
+        return self._transform_data(method, self.target, self._target_batch, **kwargs)
+
+    @copy_docstring(_transform_data)
+    def transform_template(
+        self, method: str = "phase_randomization", reverse: bool = False, **kwargs
+    ):
+        """
+        Notes
+        -----
+        The returned template is in the original not reversed orientation.
+        """
+        template = self._get_data(
+            self._template, self._output_template_shape, reverse, self._template_dim
+        )
+        return self._transform_data(method, template, self._template_batch, **kwargs)
+
     def computation_schedule(
         self,
         matching_method: str = "FLCSphericalMask",
@@ -723,7 +796,6 @@ class MatchingData:
         _output_shape = self._output_template_shape
         if np.prod([int(i) for i in template_mask.shape]) != np.prod(_output_shape):
             _output_shape = self._batch_shape(_output_shape, self._template_batch, True)
-
         return self._get_data(template_mask, _output_shape, True, self._template_dim)
 
     @target.setter
@@ -855,7 +927,7 @@ class MatchingData:
             rot_list.append(self.rotations[init_rot:end_rot])
         return rot_list
 
-    def _free_data(self):
+    def free(self):
         """
         Dereference data arrays owned by the class instance.
         """