pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
- pytme-0.3.2.dev0.dist-info/RECORD +136 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- pytme-0.3.1.post2.dist-info/RECORD +0 -133
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/filters/whitening.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
"""
|
2
|
-
Implements class
|
2
|
+
Implements class LinearWhiteningFilter
|
3
3
|
|
4
4
|
Copyright (c) 2024 European Molecular Biology Laboratory
|
5
5
|
|
@@ -7,22 +7,26 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
from typing import Tuple, Dict
|
10
|
+
from dataclasses import dataclass
|
10
11
|
|
11
12
|
import numpy as np
|
12
13
|
from scipy.ndimage import mean as ndimean
|
13
14
|
from scipy.ndimage import map_coordinates
|
14
15
|
|
16
|
+
from ._utils import fftfreqn
|
15
17
|
from ..types import BackendArray
|
18
|
+
from ..analyzer.peaks import batchify
|
16
19
|
from ..backends import backend as be
|
17
20
|
from .compose import ComposableFilter
|
18
|
-
|
21
|
+
|
19
22
|
|
20
23
|
__all__ = ["LinearWhiteningFilter"]
|
21
24
|
|
22
25
|
|
26
|
+
@dataclass
|
23
27
|
class LinearWhiteningFilter(ComposableFilter):
|
24
28
|
"""
|
25
|
-
|
29
|
+
Generate Fourier whitening filters.
|
26
30
|
|
27
31
|
References
|
28
32
|
----------
|
@@ -34,12 +38,9 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
34
38
|
13375 (2023)
|
35
39
|
"""
|
36
40
|
|
37
|
-
def __init__(self, *args, **kwargs):
|
38
|
-
pass
|
39
|
-
|
40
41
|
@staticmethod
|
41
42
|
def _compute_spectrum(
|
42
|
-
data_rfft: BackendArray, n_bins: int = None
|
43
|
+
data_rfft: BackendArray, n_bins: int = None
|
43
44
|
) -> Tuple[BackendArray, BackendArray]:
|
44
45
|
"""
|
45
46
|
Compute the power spectrum of the input data.
|
@@ -50,8 +51,6 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
50
51
|
The Fourier transform of the input data.
|
51
52
|
n_bins : int, optional
|
52
53
|
The number of bins for computing the spectrum, defaults to None.
|
53
|
-
batch_dimension : int, optional
|
54
|
-
Batch dimension to average over.
|
55
54
|
|
56
55
|
Returns
|
57
56
|
-------
|
@@ -60,7 +59,7 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
60
59
|
radial_averages : BackendArray
|
61
60
|
Array containing the radial averages of the spectrum.
|
62
61
|
"""
|
63
|
-
shape =
|
62
|
+
shape = data_rfft.shape
|
64
63
|
|
65
64
|
max_bins = max(max(shape[:-1]) // 2 + 1, shape[-1])
|
66
65
|
n_bins = max_bins if n_bins is None else n_bins
|
@@ -71,25 +70,22 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
71
70
|
sampling_rate=0.5,
|
72
71
|
shape_is_real_fourier=True,
|
73
72
|
compute_euclidean_norm=True,
|
73
|
+
fftshift=False,
|
74
74
|
)
|
75
75
|
bins = be.to_numpy_array(bins)
|
76
|
-
|
77
|
-
# Implicit lowpass to nyquist
|
78
76
|
bins = np.floor(bins * (n_bins - 1) + 0.5).astype(int)
|
79
|
-
|
80
|
-
|
81
|
-
)
|
82
|
-
fourier_spectrum = np.fft.fftshift(data_rfft, axes=fft_shift_axes)
|
83
|
-
fourier_spectrum = np.abs(fourier_spectrum)
|
84
|
-
np.square(fourier_spectrum, out=fourier_spectrum)
|
77
|
+
|
78
|
+
fourier_spectrum = np.abs(data_rfft)
|
79
|
+
fourier_spectrum = np.square(fourier_spectrum, out=fourier_spectrum)
|
85
80
|
|
86
81
|
radial_averages = ndimean(
|
87
82
|
fourier_spectrum, labels=bins, index=np.arange(n_bins)
|
88
83
|
)
|
89
|
-
np.sqrt(radial_averages, out=radial_averages)
|
90
|
-
np.
|
91
|
-
|
92
|
-
|
84
|
+
radial_averages = np.sqrt(radial_averages, out=radial_averages)
|
85
|
+
radial_averages = np.where(radial_averages != 0, 1 / radial_averages, 0)
|
86
|
+
norm_factor = radial_averages.max()
|
87
|
+
if norm_factor != 0:
|
88
|
+
radial_averages = np.divide(radial_averages, norm_factor)
|
93
89
|
return bins, radial_averages
|
94
90
|
|
95
91
|
@staticmethod
|
@@ -104,21 +100,19 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
104
100
|
sampling_rate=0.5,
|
105
101
|
shape_is_real_fourier=shape_is_real_fourier,
|
106
102
|
compute_euclidean_norm=True,
|
103
|
+
fftshift=False,
|
107
104
|
)
|
108
105
|
grid = be.to_numpy_array(grid)
|
109
|
-
np.multiply(grid,
|
106
|
+
grid = np.floor(np.multiply(grid, spectrum.shape[0] - 1) + 0.5)
|
110
107
|
spectrum = map_coordinates(spectrum, grid.reshape(1, -1), order=order)
|
111
108
|
return spectrum.reshape(grid.shape)
|
112
109
|
|
113
|
-
def
|
110
|
+
def _evaluate(
|
114
111
|
self,
|
115
|
-
shape: Tuple[int],
|
116
|
-
|
117
|
-
|
118
|
-
n_bins: int = None,
|
119
|
-
batch_dimension: int = None,
|
112
|
+
shape: Tuple[int, ...],
|
113
|
+
data_rfft: BackendArray,
|
114
|
+
axes: Tuple[int] = (),
|
120
115
|
order: int = 1,
|
121
|
-
return_real_fourier: bool = True,
|
122
116
|
**kwargs: Dict,
|
123
117
|
) -> Dict:
|
124
118
|
"""
|
@@ -128,59 +122,29 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
128
122
|
----------
|
129
123
|
shape : tuple of ints
|
130
124
|
Shape of the returned whitening filter.
|
131
|
-
data : BackendArray, optional
|
132
|
-
The input data, defaults to None.
|
133
125
|
data_rfft : BackendArray, optional
|
134
126
|
The Fourier transform of the input data, defaults to None.
|
135
|
-
|
136
|
-
|
137
|
-
batch_dimension : int, optional
|
138
|
-
Batch dimension to average over.
|
139
|
-
return_real_fourier : tuple of int
|
140
|
-
Return a shape compliant with rfft, i.e., omit the negative frequencies
|
141
|
-
terms resulting in a return shape (*shape[:-1], shape[-1]//2+1)
|
127
|
+
axes : tuple of ints, optional
|
128
|
+
Axes to compute spectrum for independently.
|
142
129
|
**kwargs : Dict
|
143
130
|
Additional keyword arguments.
|
144
|
-
|
145
|
-
Returns
|
146
|
-
-------
|
147
|
-
dict
|
148
|
-
data: BackendArray
|
149
|
-
The filter mask.
|
150
|
-
shape: tuple of ints
|
151
|
-
The requested filter shape
|
152
|
-
return_real_fourier: bool
|
153
|
-
Whether data is compliant with rfftn.
|
154
|
-
is_multiplicative_filter: bool
|
155
|
-
Whether the filter is multiplicative in Fourier space.
|
156
131
|
"""
|
157
|
-
if
|
158
|
-
|
132
|
+
if isinstance(axes, int):
|
133
|
+
axes = (axes,)
|
159
134
|
|
135
|
+
stack = []
|
160
136
|
data_rfft = be.to_numpy_array(data_rfft)
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
shape_filter = shape
|
167
|
-
if return_real_fourier:
|
168
|
-
shape_filter = compute_fourier_shape(
|
137
|
+
for subset, _ in batchify(data_rfft.shape, axes):
|
138
|
+
_, radial_avg = self._compute_spectrum(np.squeeze(data_rfft[subset]))
|
139
|
+
ret = self._interpolate_spectrum(
|
140
|
+
spectrum=radial_avg,
|
169
141
|
shape=shape,
|
170
142
|
shape_is_real_fourier=False,
|
143
|
+
order=order,
|
171
144
|
)
|
145
|
+
stack.append(ret)
|
172
146
|
|
173
|
-
ret =
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
)
|
178
|
-
|
179
|
-
ret = shift_fourier(data=ret, shape_is_real_fourier=return_real_fourier)
|
180
|
-
|
181
|
-
return {
|
182
|
-
"data": be.to_backend_array(ret),
|
183
|
-
"shape": shape,
|
184
|
-
"return_real_fourier": return_real_fourier,
|
185
|
-
"is_multiplicative_filter": True,
|
186
|
-
}
|
147
|
+
ret = np.array(stack)
|
148
|
+
if not len(axes):
|
149
|
+
ret = np.squeeze(ret)
|
150
|
+
return {"data": be.to_backend_array(ret), "shape": shape}
|
tme/mask.py
CHANGED
@@ -11,7 +11,7 @@ from typing import Tuple, Optional
|
|
11
11
|
|
12
12
|
from .types import NDArray
|
13
13
|
from scipy.ndimage import gaussian_filter
|
14
|
-
from .matching_utils import
|
14
|
+
from .matching_utils import _rigid_transform
|
15
15
|
|
16
16
|
__all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]
|
17
17
|
|
@@ -76,7 +76,7 @@ def elliptical_mask(
|
|
76
76
|
if orientation is not None:
|
77
77
|
return_shape = indices.shape
|
78
78
|
indices = indices.reshape(n, -1)
|
79
|
-
|
79
|
+
_rigid_transform(
|
80
80
|
coordinates=indices,
|
81
81
|
rotation_matrix=np.asarray(orientation),
|
82
82
|
out=indices,
|
tme/matching_data.py
CHANGED
@@ -15,7 +15,7 @@ from . import Density
|
|
15
15
|
from .filters import Compose
|
16
16
|
from .backends import backend as be
|
17
17
|
from .types import BackendArray, NDArray
|
18
|
-
from .matching_utils import compute_parallelization_schedule
|
18
|
+
from .matching_utils import compute_parallelization_schedule, copy_docstring
|
19
19
|
|
20
20
|
__all__ = ["MatchingData"]
|
21
21
|
|
@@ -249,8 +249,8 @@ class MatchingData:
|
|
249
249
|
target_offset = np.zeros(len(self._output_target_shape), dtype=int)
|
250
250
|
target_offset[mask] = [x.start for x in target_slice]
|
251
251
|
mask = np.subtract(1, self._target_batch).astype(bool)
|
252
|
-
template_offset = np.zeros(len(self._output_template_shape), dtype=int)
|
253
|
-
template_offset[mask] = [x.start for x in template_slice]
|
252
|
+
# template_offset = np.zeros(len(self._output_template_shape), dtype=int)
|
253
|
+
# template_offset[mask] = [x.start for x in template_slice]
|
254
254
|
|
255
255
|
translation_offset = tuple(x for x in target_offset)
|
256
256
|
|
@@ -485,14 +485,18 @@ class MatchingData:
|
|
485
485
|
|
486
486
|
@staticmethod
|
487
487
|
def _fourier_padding(
|
488
|
-
target_shape:
|
489
|
-
template_shape:
|
490
|
-
|
488
|
+
target_shape: NDArray,
|
489
|
+
template_shape: NDArray,
|
490
|
+
target_batch: NDArray = None,
|
491
|
+
template_batch: NDArray = None,
|
491
492
|
**kwargs,
|
492
493
|
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
493
|
-
if
|
494
|
-
|
495
|
-
|
494
|
+
if target_batch is None:
|
495
|
+
target_batch = np.zeros_like(target_shape)
|
496
|
+
if template_batch is None:
|
497
|
+
template_batch = np.zeros_like(target_shape)
|
498
|
+
|
499
|
+
batch_mask = np.logical_or(target_batch, template_batch)
|
496
500
|
|
497
501
|
fourier_pad = np.ones(len(template_shape), dtype=int)
|
498
502
|
fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
|
@@ -500,7 +504,9 @@ class MatchingData:
|
|
500
504
|
|
501
505
|
# Avoid padding batch dimensions
|
502
506
|
pad_shape = np.maximum(target_shape, template_shape)
|
503
|
-
pad_shape = np.
|
507
|
+
pad_shape = np.where(target_batch, target_shape, pad_shape)
|
508
|
+
pad_shape = np.where(template_batch, template_shape, pad_shape)
|
509
|
+
|
504
510
|
ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
|
505
511
|
conv_shape, fast_shape, fast_ft_shape = ret
|
506
512
|
|
@@ -538,10 +544,15 @@ class MatchingData:
|
|
538
544
|
--------
|
539
545
|
>>> conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=True)
|
540
546
|
"""
|
547
|
+
target_shape = kwargs.get("target_shape", self._output_target_shape)
|
548
|
+
template_shape = kwargs.get("template_shape", self._output_template_shape)
|
549
|
+
target_batch = kwargs.get("target_batch", self._target_batch)
|
550
|
+
template_batch = kwargs.get("template_batch", self._template_batch)
|
541
551
|
return self._fourier_padding(
|
542
|
-
target_shape=be.to_numpy_array(
|
543
|
-
template_shape=be.to_numpy_array(
|
544
|
-
|
552
|
+
target_shape=be.to_numpy_array(target_shape),
|
553
|
+
template_shape=be.to_numpy_array(template_shape),
|
554
|
+
target_batch=be.to_numpy_array(target_batch),
|
555
|
+
template_batch=be.to_numpy_array(template_batch),
|
545
556
|
)
|
546
557
|
|
547
558
|
def _score_mask(self, fast_shape: Tuple[int], shift: Tuple[int]) -> BackendArray:
|
@@ -568,6 +579,68 @@ class MatchingData:
|
|
568
579
|
)
|
569
580
|
return be.to_backend_array(score_mask)
|
570
581
|
|
582
|
+
def _transform_data(
|
583
|
+
self, method: str, data: BackendArray, batch_mask: Tuple[int], **kwargs
|
584
|
+
) -> BackendArray:
|
585
|
+
"""
|
586
|
+
Transform data using the specified method.
|
587
|
+
|
588
|
+
Parameters
|
589
|
+
----------
|
590
|
+
method : str, optional
|
591
|
+
Transformation method, default "phase_randomization".
|
592
|
+
- "phase_randomization": Scrambles phase while preserving amplitude spectrum
|
593
|
+
- "standardize": Standardize to zero mean and unit variance
|
594
|
+
- "laplace": Applies Laplacian edge detection filter
|
595
|
+
**kwargs : dict
|
596
|
+
Method-specific arguments (e.g., mode="wrap" for laplace).
|
597
|
+
|
598
|
+
Returns
|
599
|
+
-------
|
600
|
+
BackendArray
|
601
|
+
Transformed data.
|
602
|
+
"""
|
603
|
+
from scipy.ndimage import laplace
|
604
|
+
from .matching_utils import scramble_phases, standardize
|
605
|
+
|
606
|
+
def _standardize(arr: NDArray, **kwargs) -> NDArray:
|
607
|
+
return standardize(arr, 1, arr.size)
|
608
|
+
|
609
|
+
_supported_methods = {
|
610
|
+
"phase_randomization": scramble_phases,
|
611
|
+
"laplace": laplace,
|
612
|
+
"standardize": _standardize,
|
613
|
+
}
|
614
|
+
func = _supported_methods.get(method)
|
615
|
+
if func is None:
|
616
|
+
_supported = ",".join([str(x) for x in _supported_methods])
|
617
|
+
raise ValueError(f"Only methods {_supported} are supported.")
|
618
|
+
|
619
|
+
data = be.to_numpy_array(data)
|
620
|
+
|
621
|
+
ret = np.zeros_like(data)
|
622
|
+
for subset in self._batch_iter(data.shape, batch_mask):
|
623
|
+
ret[subset] = func(data[subset], **kwargs)
|
624
|
+
return be.to_backend_array(ret)
|
625
|
+
|
626
|
+
@copy_docstring(_transform_data)
|
627
|
+
def transform_target(self, method: str = "phase_randomization", **kwargs):
|
628
|
+
return self._transform_data(method, self.target, self._target_batch, **kwargs)
|
629
|
+
|
630
|
+
@copy_docstring(_transform_data)
|
631
|
+
def transform_template(
|
632
|
+
self, method: str = "phase_randomization", reverse: bool = False, **kwargs
|
633
|
+
):
|
634
|
+
"""
|
635
|
+
Notes
|
636
|
+
-----
|
637
|
+
The returned template is in the original not reversed orientation.
|
638
|
+
"""
|
639
|
+
template = self._get_data(
|
640
|
+
self._template, self._output_template_shape, reverse, self._template_dim
|
641
|
+
)
|
642
|
+
return self._transform_data(method, template, self._template_batch, **kwargs)
|
643
|
+
|
571
644
|
def computation_schedule(
|
572
645
|
self,
|
573
646
|
matching_method: str = "FLCSphericalMask",
|
@@ -723,7 +796,6 @@ class MatchingData:
|
|
723
796
|
_output_shape = self._output_template_shape
|
724
797
|
if np.prod([int(i) for i in template_mask.shape]) != np.prod(_output_shape):
|
725
798
|
_output_shape = self._batch_shape(_output_shape, self._template_batch, True)
|
726
|
-
|
727
799
|
return self._get_data(template_mask, _output_shape, True, self._template_dim)
|
728
800
|
|
729
801
|
@target.setter
|
@@ -855,7 +927,7 @@ class MatchingData:
|
|
855
927
|
rot_list.append(self.rotations[init_rot:end_rot])
|
856
928
|
return rot_list
|
857
929
|
|
858
|
-
def
|
930
|
+
def free(self):
|
859
931
|
"""
|
860
932
|
Dereference data arrays owned by the class instance.
|
861
933
|
"""
|