pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. pytme-0.2.0.data/scripts/match_template.py +1019 -0
  2. pytme-0.2.0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0.dist-info/RECORD +72 -0
  6. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +459 -218
  9. pytme-0.1.9.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +533 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +35 -6
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +173 -78
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +76 -33
  23. tme/matching_exhaustive.py +354 -225
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +753 -649
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +176 -0
  30. tme/preprocessing/composable_filter.py +30 -0
  31. tme/preprocessing/compose.py +52 -0
  32. tme/preprocessing/frequency_filters.py +322 -0
  33. tme/preprocessing/tilt_series.py +967 -0
  34. tme/preprocessor.py +35 -25
  35. tme/structure.py +2 -37
  36. pytme-0.1.9.data/scripts/postprocess.py +0 -625
  37. pytme-0.1.9.dist-info/RECORD +0 -61
  38. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
  39. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
  40. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
  41. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,176 @@
1
+ """ Utilities for the generation of frequency grids.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ from numpy.typing import NDArray
12
+
13
+ from ..backends import backend
14
+ from ..matching_utils import euler_to_rotationmatrix
15
+
16
+
17
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
    """
    Compute the shape that remains once ``opening_axis`` is collapsed.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the input array.
    opening_axis : int
        Index of the axis to collapse. If it matches no axis (e.g. None),
        the shape is returned unchanged.
    reduce_dim : bool, optional (default=False)
        If True, the opening axis is removed entirely; otherwise it is kept
        with length one.

    Returns:
    --------
    Tuple[int]
        The resulting shape.
    """
    if reduce_dim:
        return tuple(size for axis, size in enumerate(shape) if axis != opening_axis)
    return tuple(1 if axis == opening_axis else size for axis, size in enumerate(shape))
40
+
41
+
42
def centered_grid(shape: Tuple[int]) -> NDArray:
    """
    Build integer coordinate grids whose origin sits at ``size // 2`` per axis.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the grid.

    Returns:
    --------
    NDArray
        Array of shape ``(len(shape), *shape)``; entry ``i`` holds the offset
        of every grid point from the center along axis ``i``.
    """
    # One shifted index range per axis, combined with matrix ("ij") indexing
    # so that output axis order follows ``shape``.
    offsets = [np.arange(length) - length // 2 for length in shape]
    grids = np.meshgrid(*offsets, indexing="ij")
    return np.array(grids)
60
+
61
+
62
def frequency_grid_at_angle(
    shape: Tuple[int],
    angle: float,
    sampling_rate: Tuple[float],
    opening_axis: int = None,
    tilt_axis: int = None,
) -> NDArray:
    """
    Compute frequency magnitudes on a centered grid, optionally rotated by ``angle``.

    The grid spans ``shape`` with ``opening_axis`` collapsed to length one.
    Each axis is scaled so that an index offset of ``shape // 2`` maps to
    ``1 / (2 * sampling_rate)``, and the Euclidean norm over axes is returned.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the grid.
    angle : float
        Rotation angle; the grid is rotated only when ``angle != 0``.
    sampling_rate : Tuple[float]
        The sampling rate for each dimension, or a single value that is
        repeated for every dimension.
    opening_axis : int, optional
        Axis collapsed to length one. Must be provided when ``angle`` is
        non-zero, since it determines how the Euler angles are rolled.
    tilt_axis : int, optional
        Index into the Euler-angle vector that receives ``angle``.

    Returns:
    --------
    NDArray
        The frequency magnitudes, with singleton axes squeezed out.
    """
    n_dims = len(shape)
    rates = np.array(sampling_rate)
    # A scalar sampling rate (size 1) is repeated once per dimension.
    rates = np.repeat(rates, n_dims // rates.size)

    plane_shape = compute_tilt_shape(
        shape=shape, opening_axis=opening_axis, reduce_dim=False
    )
    coordinates = centered_grid(shape=plane_shape)

    if angle != 0:
        euler_angles = np.zeros(n_dims)
        euler_angles[tilt_axis] = angle
        rotation = euler_to_rotationmatrix(np.roll(euler_angles, opening_axis - 1))
        # Apply the rotation matrix to the coordinate vector of every grid point.
        coordinates = np.einsum("ij,j...->i...", rotation, coordinates)

    half_extent = np.divide(shape, 2).astype(int)
    scale = np.divide(1, 2 * rates * half_extent)
    # Transposing moves the coordinate axis last so the per-axis scale broadcasts.
    coordinates = np.multiply(coordinates.T, scale).T
    return np.linalg.norm(np.squeeze(coordinates), axis=0)
109
+
110
+
111
def fftfreqn(
    shape: Tuple[int],
    sampling_rate: Tuple[float],
    compute_euclidean_norm: bool = False,
    shape_is_real_fourier: bool = False,
) -> NDArray:
    """
    Generate the n-dimensional discrete Fourier Transform sample frequencies.

    The zero frequency sits at index ``shape // 2`` along every axis, except
    along the last axis when ``shape_is_real_fourier`` is set, where it sits
    at index 0.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the data.
    sampling_rate : float or Tuple[float]
        The sampling rate, either a single value or one value per axis. If
        None, the returned values are plain index offsets from the zero
        frequency.
    compute_euclidean_norm : bool, optional
        Whether to return the Euclidean norm over axes instead of per-axis
        frequencies, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape corresponds to a real Fourier transform, defaults to False.

    Returns:
    --------
    NDArray
        Per-axis sample frequencies of shape ``(len(shape), *shape)``, or their
        Euclidean norm of shape ``shape`` if ``compute_euclidean_norm`` is set.
    """
    center = backend.astype(backend.divide(shape, 2), backend._default_dtype_int)

    # One normalization factor per axis. This was previously fixed at three
    # entries, which failed to broadcast for 1D/2D shapes when sampling_rate
    # is None.
    norm = np.ones(len(shape))
    if sampling_rate is not None:
        norm = backend.multiply(shape, sampling_rate).astype(int)

    if shape_is_real_fourier:
        # The last axis of a real-input transform starts at the zero frequency.
        center[-1] = 0
        norm[-1] = 1
        if sampling_rate is not None:
            # (shape[-1] - 1) * 2 recovers the full axis length (for an
            # even-length original axis). Use only the last axis' rate:
            # ``2 * tuple`` would repeat the tuple instead of scaling it.
            last_rate = np.ravel(sampling_rate)[-1]
            norm[-1] = (shape[-1] - 1) * 2 * last_rate

    # Transposing puts the axis dimension last, so center and norm
    # broadcast against the per-point index vectors.
    indices = backend.transpose(backend.indices(shape))
    indices -= center
    indices = backend.divide(indices, norm)
    indices = backend.transpose(indices)

    if compute_euclidean_norm:
        backend.square(indices, indices)
        indices = backend.sum(indices, axis=0)
        indices = backend.sqrt(indices)

    return indices
159
+
160
+
161
def crop_real_fourier(data: NDArray) -> NDArray:
    """
    Keep the non-redundant half of a full Fourier transform along the last axis.

    Parameters:
    -----------
    data : NDArray
        The Fourier transformed data.

    Returns:
    --------
    NDArray
        A view of ``data`` restricted to the first ``n // 2 + 1`` entries of
        the last axis, where ``n`` is that axis' length.
    """
    last_length = data.shape[-1]
    return data[..., : last_length // 2 + 1]
@@ -0,0 +1,30 @@
1
+ """ Defines a specification for filters that can be used with
2
+ :py:class:`tme.preprocessing.compose.Compose`.
3
+
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ from typing import Dict
9
+ from abc import ABC, abstractmethod
10
+
11
class ComposableFilter(ABC):
    """
    Abstract interface for filters that can be chained with
    :py:class:`tme.preprocessing.compose.Compose`.

    Concrete filters must implement :py:meth:`__call__`; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Dict:
        """
        Apply the filter.

        Parameters:
        -----------
        *args : tuple
            Positional arguments forwarded by the caller.
        **kwargs : dict
            Keyword arguments forwarded by the caller.

        Returns:
        --------
        Dict
            A dictionary representing the result of the filtering operation.
        """
@@ -0,0 +1,52 @@
1
+ """ Combine filters using an interface analogous to pytorch's Compose.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple, Dict
9
+
10
+ from tme.backends import backend
11
+
12
+
13
class Compose:
    """
    Chain several transformations into one callable.

    Every transformation is called with keyword arguments and must return a
    dictionary. The dictionary produced by one transformation is merged into
    the keyword arguments of the next. If a result is flagged with
    ``is_multiplicative_filter``, its ``data`` is multiplied in place with the
    ``data`` of the preceding result.

    Parameters:
    -----------
    transforms : Tuple[object]
        Callables applied in order.

    Returns:
    --------
    Dict
        Output of the last transformation, or an empty dict when there are none.
    """

    def __init__(self, transforms: Tuple[object]):
        self.transforms = transforms

    def __call__(self, **kwargs: Dict) -> Dict:
        if len(self.transforms) == 0:
            return {}

        result = self.transforms[0](**kwargs)
        for step in self.transforms[1:]:
            # Expose the previous output to the next transformation.
            kwargs.update(result)
            current = step(**kwargs)

            if current.get("is_multiplicative_filter", False):
                # Combine with the previous data in place.
                backend.multiply(current["data"], result["data"], current["data"])
                current["merge"] = None

            result = current

        return result
@@ -0,0 +1,322 @@
1
+ """ Defines Fourier frequency filters.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ from typing import Tuple, Dict
8
+
9
+ import numpy as np
10
+ from numpy.typing import NDArray
11
+ from scipy.ndimage import mean as ndimean
12
+
13
+ from ._utils import fftfreqn, crop_real_fourier
14
+ from ..backends import backend
15
+
16
+
17
class BandPassFilter:
    """
    Generate bandpass filters in Fourier space, either with hard frequency
    cutoffs (:py:meth:`discrete_bandpass`) or with Gaussian fall-offs
    (:py:meth:`gaussian_bandpass`).

    Parameters:
    -----------
    lowpass : float, optional
        The lowpass cutoff in units of sampling rate, defaults to None.
    highpass : float, optional
        The highpass cutoff in units of sampling rate, defaults to None.
    sampling_rate : float or Tuple[float], optional
        The sampling rate of the data, defaults to 1.
    use_gaussian : bool, optional
        Whether to use Gaussian bandpass filter, defaults to True.
    return_real_fourier : bool, optional
        Whether to return only the real Fourier space, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape represents the real Fourier space, defaults to False.
    """

    def __init__(
        self,
        lowpass: float = None,
        highpass: float = None,
        sampling_rate: Tuple[float] = 1,
        use_gaussian: bool = True,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
    ):
        self.lowpass = lowpass
        self.highpass = highpass
        self.use_gaussian = use_gaussian
        self.return_real_fourier = return_real_fourier
        self.shape_is_real_fourier = shape_is_real_fourier
        self.sampling_rate = sampling_rate

    @staticmethod
    def discrete_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: Tuple[float],
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> NDArray:
        """
        Generate a bandpass filter using discrete frequency cutoffs.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None or 0 disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float or tuple of float
            The sampling rate of the data. For per-axis rates the largest
            resulting cutoff frequency is used.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
            Ignored when ``shape_is_real_fourier`` is True.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments (ignored).

        Returns:
        --------
        NDArray
            The bandpass filter in Fourier space, in unshifted layout.
        """
        # A real Fourier shape is already halved; cropping it again would
        # discard valid coefficients. Mirrors the guard in gaussian_bandpass.
        if shape_is_real_fourier:
            return_real_fourier = False

        # With sampling_rate=0.5 the grid is normalized so that an index offset
        # of shape // 2 from the center equals 1.
        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )

        lowpass = 0 if lowpass is None else lowpass
        highpass = 1e10 if highpass is None else highpass

        # Convert cutoffs into grid units. np.multiply scales tuple-valued
        # sampling rates element-wise; ``2 * tuple`` would repeat the tuple.
        nyquist = np.multiply(2, sampling_rate)
        highcut = grid.max()
        if lowpass > 0:
            highcut = np.max(np.divide(nyquist, lowpass))
        lowcut = np.max(np.divide(nyquist, highpass))

        bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0

        # Move the centered zero frequency back to index 0 (unshifted layout).
        shift = backend.add(
            backend.astype(backend.divide(bandpass_filter.shape, 2), int),
            backend.mod(bandpass_filter.shape, 2),
        )
        if shape_is_real_fourier:
            shift[-1] = 0

        bandpass_filter = backend.roll(
            bandpass_filter, shift, tuple(i for i in range(len(shift)))
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    @staticmethod
    def gaussian_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> NDArray:
        """
        Generate a bandpass filter using Gaussian functions.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float
            The sampling rate of the data.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
            Ignored when ``shape_is_real_fourier`` is True.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments (ignored).

        Returns:
        --------
        NDArray
            The bandpass filter in Fourier space, in unshifted layout.
        """
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = -backend.square(grid)

        lowpass_filter, highpass_filter = 1, 1
        # Converts a cutoff into a Gaussian width (half width at half maximum).
        norm = float(backend.sqrt(2 * backend.log(2)))
        upper_sampling = float(backend.max(backend.multiply(2, sampling_rate)))

        # Avoid division by zero for zero-valued cutoffs.
        if lowpass is not None:
            lowpass = float(lowpass)
            lowpass = backend.maximum(lowpass, backend.eps(lowpass))
        if highpass is not None:
            highpass = float(highpass)
            highpass = backend.maximum(highpass, backend.eps(highpass))

        if lowpass is not None:
            lowpass = upper_sampling / (lowpass * norm)
            lowpass = backend.multiply(2, backend.square(lowpass))
            lowpass_filter = backend.exp(backend.divide(grid, lowpass))
        if highpass is not None:
            highpass = upper_sampling / (highpass * norm)
            highpass = backend.multiply(2, backend.square(highpass))
            highpass_filter = 1 - backend.exp(backend.divide(grid, highpass))

        lowpass_filter = backend.multiply(lowpass_filter, highpass_filter)

        # Move the centered zero frequency back to index 0 (unshifted layout).
        shift = backend.add(
            backend.astype(backend.divide(lowpass_filter.shape, 2), int),
            backend.mod(lowpass_filter.shape, 2),
        )
        if shape_is_real_fourier:
            shift[-1] = 0

        lowpass_filter = backend.roll(
            lowpass_filter, shift, tuple(i for i in range(len(shift)))
        )

        if return_real_fourier:
            lowpass_filter = crop_real_fourier(lowpass_filter)

        return lowpass_filter

    def __call__(self, **kwargs):
        """
        Build the filter mask described by the instance attributes.

        Parameters:
        -----------
        **kwargs : dict
            Per-call overrides for the instance attributes (e.g. ``lowpass``)
            plus arguments the bandpass functions require, such as ``shape``.
            Overrides apply to this call only.

        Returns:
        --------
        Dict
            ``data`` holds the mask as a backend array, ``sampling_rate`` the
            rate used, and ``is_multiplicative_filter`` is True.
        """
        # vars(self) is the live instance __dict__; merge into a new dict so
        # per-call keyword arguments do not permanently overwrite attributes.
        func_args = {**vars(self), **kwargs}

        func = self.discrete_bandpass
        if func_args.get("use_gaussian"):
            func = self.gaussian_bandpass

        mask = func(**func_args)

        return {
            "data": backend.to_backend_array(mask),
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }
220
+
221
+
222
class LinearWhiteningFilter:
    """
    Compute a radially averaged amplitude spectrum of the input and turn it
    into a multiplicative whitening filter for its Fourier coefficients.

    Parameters:
    -----------
    **kwargs : Dict, optional
        Additional keyword arguments; accepted for interface compatibility
        and ignored.
    """

    def __init__(self, **kwargs):
        pass

    @staticmethod
    def _compute_spectrum(
        data_rfft: NDArray, n_bins: int = None
    ) -> Tuple[NDArray, NDArray]:
        """
        Compute the inverse radial amplitude spectrum of the input data.

        Parameters:
        -----------
        data_rfft : NDArray
            The real-input Fourier transform of the data (unshifted layout).
        n_bins : int, optional
            The number of radial bins, capped at a shape-derived maximum.
            Defaults to that maximum. Values below 2 are not supported by the
            underlying histogram.

        Returns:
        --------
        bins : NDArray
            Radial bin label of every Fourier coefficient, in fftshift-ed
            order along all but the last axis.
        radial_averages : NDArray
            Reciprocal root mean power per bin, normalized to a maximum of one.
        """
        # default=0 keeps 1D input working, where shape[:-1] is empty.
        max_bins = max(
            max(data_rfft.shape[:-1], default=0) // 2 + 1, data_rfft.shape[-1]
        )
        n_bins = max_bins if n_bins is None else n_bins
        n_bins = int(min(n_bins, max_bins))

        # With sampling_rate=None the grid holds index distances to the zero
        # frequency (centered on all but the last axis).
        grid = fftfreqn(
            shape=data_rfft.shape,
            sampling_rate=None,
            shape_is_real_fourier=True,
            compute_euclidean_norm=True,
        )
        _, bin_edges = np.histogram(grid, bins=n_bins - 1)
        bins = np.digitize(grid, bins=bin_edges, right=True)

        # Shift the coefficients so they line up with the centered grid.
        fft_shift_axes = tuple(range(data_rfft.ndim - 1))
        fourier_transform = np.fft.fftshift(data_rfft, axes=fft_shift_axes)
        fourier_spectrum = np.square(np.abs(fourier_transform))
        # NOTE(review): one mean per entry of np.unique(bins); the result is
        # later indexed by bin label, which assumes no label in
        # 0..bins.max() is empty.
        radial_averages = ndimean(fourier_spectrum, labels=bins, index=np.unique(bins))

        # Whitening weight: inverse amplitude, scaled to a maximum of one.
        np.sqrt(radial_averages, out=radial_averages)
        np.reciprocal(radial_averages, out=radial_averages)
        np.divide(radial_averages, radial_averages.max(), out=radial_averages)

        return bins, radial_averages

    def __call__(
        self,
        data: NDArray = None,
        data_rfft: NDArray = None,
        n_bins: int = None,
        **kwargs: Dict,
    ) -> Dict:
        """
        Compute a linear whitening filter for the given data.

        Parameters:
        -----------
        data : NDArray, optional
            The input data; only used when ``data_rfft`` is None.
        data_rfft : NDArray, optional
            The real-input Fourier transform of the data, defaults to None.
        n_bins : int, optional
            The number of bins for computing the spectrum, defaults to None.
        **kwargs : Dict
            Additional keyword arguments (ignored).

        Returns:
        --------
        Dict
            ``data`` holds the filter as a backend array in unshifted rfft
            layout, and ``is_multiplicative_filter`` is True.
        """
        if data_rfft is None:
            data_rfft = np.fft.rfftn(backend.to_numpy_array(data))

        data_rfft = backend.to_numpy_array(data_rfft)

        bins, radial_averages = self._compute_spectrum(data_rfft, n_bins)

        # Broadcast per-bin weights to every coefficient, then undo the
        # fftshift applied when binning.
        radial_averages = np.fft.ifftshift(
            radial_averages[bins], axes=tuple(range(data_rfft.ndim - 1))
        )

        return {
            "data": backend.to_backend_array(radial_averages),
            "is_multiplicative_filter": True,
        }
+ }