pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/match_template.py +163 -201
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +48 -39
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +3 -4
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +14 -14
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/RECORD +54 -50
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +1 -0
- pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/match_template.py +163 -201
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +48 -39
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +3 -4
- scripts/pytme_runner.py +769 -0
- scripts/refine_matches.py +0 -1
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +1 -4
- tme/analyzer/aggregation.py +15 -6
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +39 -113
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +16 -15
- tme/backends/cupy_backend.py +9 -13
- tme/backends/jax_backend.py +19 -16
- tme/backends/npfftw_backend.py +27 -25
- tme/backends/pytorch_backend.py +4 -0
- tme/density.py +5 -4
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +117 -67
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +88 -105
- tme/filters/whitening.py +1 -6
- tme/matching_data.py +24 -36
- tme/matching_exhaustive.py +14 -11
- tme/matching_scores.py +21 -12
- tme/matching_utils.py +13 -6
- tme/orientations.py +13 -3
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py
CHANGED
@@ -38,16 +38,16 @@ class AbstractAnalyzer(ABC):
|
|
38
38
|
|
39
39
|
Returns
|
40
40
|
-------
|
41
|
-
state
|
42
|
-
Initial state tuple
|
43
|
-
|
44
|
-
implementation.
|
41
|
+
state : tuple
|
42
|
+
Initial state tuple of the analyzer instance. The exact structure
|
43
|
+
depends on the specific implementation.
|
45
44
|
|
46
45
|
Notes
|
47
46
|
-----
|
48
47
|
This method creates the initial state that will be passed to
|
49
|
-
|
50
|
-
|
48
|
+
:py:meth:`AbstractAnalyzer.__call__` and finally to
|
49
|
+
:py:meth:`AbstractAnalyzer.result`. The state should contain all necessary
|
50
|
+
data structures for accumulating analysis results.
|
51
51
|
"""
|
52
52
|
|
53
53
|
@abstractmethod
|
@@ -57,49 +57,39 @@ class AbstractAnalyzer(ABC):
|
|
57
57
|
|
58
58
|
Parameters
|
59
59
|
----------
|
60
|
-
state :
|
61
|
-
Current analyzer state as returned
|
62
|
-
previous
|
60
|
+
state : tuple
|
61
|
+
Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
|
62
|
+
or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
|
63
63
|
scores : BackendArray
|
64
|
-
Array of scores
|
64
|
+
Array of new scores with dimensionality d.
|
65
65
|
rotation_matrix : BackendArray
|
66
|
-
Rotation matrix used to generate
|
66
|
+
Rotation matrix used to generate scores with shape (d,d).
|
67
67
|
**kwargs : dict
|
68
|
-
|
69
|
-
implementation.
|
68
|
+
Keyword arguments used by specific implementations.
|
70
69
|
|
71
70
|
Returns
|
72
71
|
-------
|
73
|
-
|
74
|
-
Updated analyzer state
|
75
|
-
|
76
|
-
Notes
|
77
|
-
-----
|
78
|
-
This method should be pure functional - it should not modify
|
79
|
-
the input state but return a new state with the updates applied.
|
80
|
-
The exact signature may vary between implementations.
|
72
|
+
tuple
|
73
|
+
Updated analyzer state incorporating the new data.
|
81
74
|
"""
|
82
|
-
pass
|
83
75
|
|
84
76
|
@abstractmethod
|
85
77
|
def result(self, state: Tuple, **kwargs) -> Tuple:
|
86
78
|
"""
|
87
|
-
Finalize the analysis
|
79
|
+
Finalize the analysis by performing potential post processing.
|
88
80
|
|
89
81
|
Parameters
|
90
82
|
----------
|
91
83
|
state : tuple
|
92
|
-
|
84
|
+
Analyzer state containing accumulated data.
|
93
85
|
**kwargs : dict
|
94
|
-
|
95
|
-
such as postprocessing parameters.
|
86
|
+
Keyword arguments used by specific implementations.
|
96
87
|
|
97
88
|
Returns
|
98
89
|
-------
|
99
90
|
result
|
100
|
-
Final analysis result. The exact
|
101
|
-
analyzer implementation
|
102
|
-
scores, rotation information, and metadata.
|
91
|
+
Final analysis result. The exact struccture depends on the
|
92
|
+
analyzer implementation.
|
103
93
|
|
104
94
|
Notes
|
105
95
|
-----
|
@@ -108,25 +98,24 @@ class AbstractAnalyzer(ABC):
|
|
108
98
|
It may apply postprocessing operations like convolution mode
|
109
99
|
correction or coordinate transformations.
|
110
100
|
"""
|
111
|
-
pass
|
112
101
|
|
113
102
|
@classmethod
|
114
103
|
@abstractmethod
|
115
104
|
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
|
116
105
|
"""
|
117
|
-
Merge
|
106
|
+
Merge multiple analyzer results.
|
118
107
|
|
119
108
|
Parameters
|
120
109
|
----------
|
121
|
-
results : list
|
122
|
-
List of
|
123
|
-
from
|
110
|
+
results : list of tuple
|
111
|
+
List of tuple objects returned by :py:meth:`AbstractAnalyzer.result`
|
112
|
+
from different instances of the same analyzer class.
|
124
113
|
**kwargs : dict
|
125
|
-
|
114
|
+
Keyword arguments used by specific implementations.
|
126
115
|
|
127
116
|
Returns
|
128
117
|
-------
|
129
|
-
|
118
|
+
tuple
|
130
119
|
Single result object combining all input results.
|
131
120
|
|
132
121
|
Notes
|
tme/analyzer/peaks.py
CHANGED
@@ -17,9 +17,9 @@ from skimage.registration._phase_cross_correlation import _upsampled_dft
|
|
17
17
|
from .base import AbstractAnalyzer
|
18
18
|
from ._utils import score_to_cart
|
19
19
|
from ..backends import backend as be
|
20
|
-
from ..matching_utils import split_shape
|
21
20
|
from ..types import BackendArray, NDArray
|
22
21
|
from ..rotations import euler_to_rotationmatrix
|
22
|
+
from ..matching_utils import split_shape, compute_extraction_box
|
23
23
|
|
24
24
|
__all__ = [
|
25
25
|
"PeakCaller",
|
@@ -765,14 +765,14 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
765
765
|
values. If rotations and rotation_mapping is provided, the respective
|
766
766
|
rotation will be applied to the mask, otherwise rotation_matrix is used.
|
767
767
|
"""
|
768
|
-
|
768
|
+
peaks = []
|
769
|
+
box = tuple(self.min_distance for _ in range(scores.ndim))
|
769
770
|
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
mask = be.
|
774
|
-
|
775
|
-
rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
|
771
|
+
scores = be.to_backend_array(scores)
|
772
|
+
if mask is not None:
|
773
|
+
box = mask.shape
|
774
|
+
mask = be.to_backend_array(mask)
|
775
|
+
mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
|
776
776
|
|
777
777
|
peak_limit = self.num_peaks
|
778
778
|
if min_score is not None:
|
@@ -780,39 +780,45 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
780
780
|
else:
|
781
781
|
min_score = be.min(scores) - 1
|
782
782
|
|
783
|
-
|
784
|
-
|
785
|
-
|
783
|
+
_scores = be.zeros(scores.shape, dtype=scores.dtype)
|
784
|
+
_scores[:] = scores[:]
|
786
785
|
while True:
|
787
|
-
be.argmax(
|
788
|
-
peak
|
789
|
-
indices=be.argmax(scores_copy), shape=scores_copy.shape
|
790
|
-
)
|
791
|
-
if scores_copy[tuple(peak)] < min_score:
|
786
|
+
peak = be.unravel_index(indices=be.argmax(_scores), shape=_scores.shape)
|
787
|
+
if _scores[tuple(peak)] < min_score:
|
792
788
|
break
|
789
|
+
peaks.append(peak)
|
793
790
|
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
rotation_space=rotations,
|
799
|
-
rotation_mapping=rotation_mapping,
|
800
|
-
rotation_matrix=rotation_matrix,
|
791
|
+
score_beg, score_end, tmpl_beg, tmpl_end, _ = compute_extraction_box(
|
792
|
+
centers=be.to_backend_array(peak)[None],
|
793
|
+
extraction_shape=box,
|
794
|
+
original_shape=scores.shape,
|
801
795
|
)
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
mask=mask,
|
808
|
-
rotated_template=rotated_template,
|
796
|
+
score_slice = tuple(
|
797
|
+
slice(int(x), int(y)) for x, y in zip(score_beg[0], score_end[0])
|
798
|
+
)
|
799
|
+
tmpl_slice = tuple(
|
800
|
+
slice(int(x), int(y)) for x, y in zip(tmpl_beg[0], tmpl_end[0])
|
809
801
|
)
|
810
802
|
|
811
|
-
|
803
|
+
score_mask = 0
|
804
|
+
if mask is not None:
|
805
|
+
mask_buffer.fill(0)
|
806
|
+
rmat = self._get_rotation_matrix(
|
807
|
+
peak=peak,
|
808
|
+
rotation_space=rotations,
|
809
|
+
rotation_mapping=rotation_mapping,
|
810
|
+
rotation_matrix=rotation_matrix,
|
811
|
+
)
|
812
|
+
be.rigid_transform(
|
813
|
+
arr=mask, rotation_matrix=rmat, order=1, out=mask_buffer
|
814
|
+
)
|
815
|
+
score_mask = mask_buffer[tmpl_slice] <= 0.1
|
816
|
+
|
817
|
+
_scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
|
818
|
+
if len(peaks) >= peak_limit:
|
812
819
|
break
|
813
820
|
|
814
|
-
|
815
|
-
return peaks, None
|
821
|
+
return be.to_backend_array(peaks), None
|
816
822
|
|
817
823
|
@staticmethod
|
818
824
|
def _get_rotation_matrix(
|
@@ -852,86 +858,6 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
852
858
|
)
|
853
859
|
return rotation
|
854
860
|
|
855
|
-
@staticmethod
|
856
|
-
def _mask_scores_box(
|
857
|
-
scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
|
858
|
-
) -> None:
|
859
|
-
"""
|
860
|
-
Mask scores in a box around a peak.
|
861
|
-
|
862
|
-
Parameters
|
863
|
-
----------
|
864
|
-
scores : BackendArray
|
865
|
-
Data array of scores.
|
866
|
-
peak : BackendArray
|
867
|
-
Peak coordinates.
|
868
|
-
mask : BackendArray
|
869
|
-
Mask array.
|
870
|
-
"""
|
871
|
-
start = be.maximum(be.subtract(peak, mask.shape), 0)
|
872
|
-
stop = be.minimum(be.add(peak, mask.shape), scores.shape)
|
873
|
-
start, stop = be.astype(start, int), be.astype(stop, int)
|
874
|
-
coords = tuple(slice(*pos) for pos in zip(start, stop))
|
875
|
-
scores[coords] = 0
|
876
|
-
return None
|
877
|
-
|
878
|
-
@staticmethod
|
879
|
-
def _mask_scores_rotate(
|
880
|
-
scores: BackendArray,
|
881
|
-
peak: BackendArray,
|
882
|
-
mask: BackendArray,
|
883
|
-
rotated_template: BackendArray,
|
884
|
-
rotation_matrix: BackendArray,
|
885
|
-
**kwargs: Dict,
|
886
|
-
) -> None:
|
887
|
-
"""
|
888
|
-
Mask scores using mask rotation around a peak.
|
889
|
-
|
890
|
-
Parameters
|
891
|
-
----------
|
892
|
-
scores : BackendArray
|
893
|
-
Data array of scores.
|
894
|
-
peak : BackendArray
|
895
|
-
Peak coordinates.
|
896
|
-
mask : BackendArray
|
897
|
-
Mask array.
|
898
|
-
rotated_template : BackendArray
|
899
|
-
Empty array to write mask rotations to.
|
900
|
-
rotation_matrix : BackendArray
|
901
|
-
Rotation matrix.
|
902
|
-
"""
|
903
|
-
left_pad = be.divide(mask.shape, 2).astype(int)
|
904
|
-
right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
|
905
|
-
|
906
|
-
score_start = be.subtract(peak, left_pad)
|
907
|
-
score_stop = be.add(peak, right_pad)
|
908
|
-
|
909
|
-
template_start = be.subtract(be.maximum(score_start, 0), score_start)
|
910
|
-
template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
|
911
|
-
template_stop = be.subtract(mask.shape, template_stop)
|
912
|
-
|
913
|
-
score_start = be.maximum(score_start, 0)
|
914
|
-
score_stop = be.minimum(score_stop, scores.shape)
|
915
|
-
score_start = be.astype(score_start, int)
|
916
|
-
score_stop = be.astype(score_stop, int)
|
917
|
-
|
918
|
-
template_start = be.astype(template_start, int)
|
919
|
-
template_stop = be.astype(template_stop, int)
|
920
|
-
coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
|
921
|
-
coords_template = tuple(
|
922
|
-
slice(*pos) for pos in zip(template_start, template_stop)
|
923
|
-
)
|
924
|
-
|
925
|
-
rotated_template.fill(0)
|
926
|
-
be.rigid_transform(
|
927
|
-
arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
|
928
|
-
)
|
929
|
-
|
930
|
-
scores[coords_score] = be.multiply(
|
931
|
-
scores[coords_score], (rotated_template[coords_template] <= 0.1)
|
932
|
-
)
|
933
|
-
return None
|
934
|
-
|
935
861
|
|
936
862
|
class PeakCallerScipy(PeakCaller):
|
937
863
|
"""
|
tme/analyzer/proxy.py
CHANGED
tme/backends/_jax_utils.py
CHANGED
@@ -10,16 +10,19 @@ from typing import Tuple
|
|
10
10
|
from functools import partial
|
11
11
|
|
12
12
|
import jax.numpy as jnp
|
13
|
-
from jax import pmap, lax
|
13
|
+
from jax import pmap, lax, vmap
|
14
14
|
|
15
15
|
from ..types import BackendArray
|
16
16
|
from ..backends import backend as be
|
17
17
|
from ..matching_utils import normalize_template as _normalize_template
|
18
18
|
|
19
19
|
|
20
|
+
__all__ = ["scan"]
|
21
|
+
|
22
|
+
|
20
23
|
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
21
24
|
"""
|
22
|
-
Computes :py:meth:`tme.
|
25
|
+
Computes :py:meth:`tme.matching_scores.cc_setup`.
|
23
26
|
"""
|
24
27
|
template_ft = jnp.fft.rfftn(template, s=template.shape)
|
25
28
|
template_ft = template_ft.at[:].multiply(ft_target)
|
@@ -28,18 +31,17 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
|
28
31
|
|
29
32
|
|
30
33
|
def _flc_scoring(
|
31
|
-
template: BackendArray,
|
32
|
-
template_mask: BackendArray,
|
33
34
|
ft_target: BackendArray,
|
34
35
|
ft_target2: BackendArray,
|
36
|
+
template: BackendArray,
|
37
|
+
template_mask: BackendArray,
|
35
38
|
n_observations: BackendArray,
|
36
39
|
eps: float,
|
37
40
|
**kwargs,
|
38
41
|
) -> BackendArray:
|
39
42
|
"""
|
40
|
-
Computes :py:meth:`tme.
|
43
|
+
Computes :py:meth:`tme.matching_scores.flc_scoring`.
|
41
44
|
"""
|
42
|
-
correlation = _correlate(template=template, ft_target=ft_target)
|
43
45
|
inv_denominator = _reciprocal_target_std(
|
44
46
|
ft_target=ft_target,
|
45
47
|
ft_target2=ft_target2,
|
@@ -47,18 +49,17 @@ def _flc_scoring(
|
|
47
49
|
eps=eps,
|
48
50
|
n_observations=n_observations,
|
49
51
|
)
|
50
|
-
|
51
|
-
return correlation
|
52
|
+
return _flcSphere_scoring(ft_target, template, inv_denominator)
|
52
53
|
|
53
54
|
|
54
55
|
def _flcSphere_scoring(
|
55
|
-
template: BackendArray,
|
56
56
|
ft_target: BackendArray,
|
57
|
+
template: BackendArray,
|
57
58
|
inv_denominator: BackendArray,
|
58
59
|
**kwargs,
|
59
60
|
) -> BackendArray:
|
60
61
|
"""
|
61
|
-
Computes :py:meth:`tme.
|
62
|
+
Computes :py:meth:`tme.matching_scores.corr_scoring`.
|
62
63
|
"""
|
63
64
|
correlation = _correlate(template=template, ft_target=ft_target)
|
64
65
|
correlation = correlation.at[:].multiply(inv_denominator)
|
@@ -77,7 +78,7 @@ def _reciprocal_target_std(
|
|
77
78
|
|
78
79
|
See Also
|
79
80
|
--------
|
80
|
-
:py:meth:`tme.
|
81
|
+
:py:meth:`tme.matching_scores.flc_scoring`.
|
81
82
|
"""
|
82
83
|
ft_shape = template_mask.shape
|
83
84
|
ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
|
@@ -163,12 +164,12 @@ def scan(
|
|
163
164
|
template_rot = _normalize_template(
|
164
165
|
template_rot, template_mask_rot, n_observations
|
165
166
|
)
|
166
|
-
|
167
|
-
|
167
|
+
rot_pad = be.topleft_pad(template_rot, fast_shape)
|
168
|
+
mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
|
168
169
|
|
169
170
|
scores = scoring_func(
|
170
|
-
template=
|
171
|
-
template_mask=
|
171
|
+
template=rot_pad,
|
172
|
+
template_mask=mask_rot_pad,
|
172
173
|
ft_target=ft_target,
|
173
174
|
ft_target2=ft_target2,
|
174
175
|
inv_denominator=inv_denominator,
|
tme/backends/cupy_backend.py
CHANGED
@@ -6,13 +6,10 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
import warnings
|
10
9
|
from importlib.util import find_spec
|
11
10
|
from contextlib import contextmanager
|
12
11
|
from typing import Tuple, Callable, List
|
13
12
|
|
14
|
-
import numpy as np
|
15
|
-
|
16
13
|
from .npfftw_backend import NumpyFFTWBackend
|
17
14
|
from ..types import CupyArray, NDArray, shm_type
|
18
15
|
|
@@ -146,15 +143,14 @@ class CupyBackend(NumpyFFTWBackend):
|
|
146
143
|
def rfftn(
|
147
144
|
arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
|
148
145
|
) -> CupyArray:
|
149
|
-
return self.rfftn(arr, s=s, axes=fwd_axes)
|
146
|
+
return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)
|
150
147
|
|
151
148
|
def irfftn(
|
152
149
|
arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
|
153
150
|
) -> CupyArray:
|
154
|
-
return self.irfftn(arr, s=s, axes=inv_axes)
|
151
|
+
return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)
|
155
152
|
|
156
153
|
PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
|
157
|
-
|
158
154
|
return rfftn, irfftn
|
159
155
|
|
160
156
|
def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
|
@@ -239,13 +235,13 @@ class CupyBackend(NumpyFFTWBackend):
|
|
239
235
|
)
|
240
236
|
return None
|
241
237
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
238
|
+
if data.ndim == 3 and cache and self.texture_available:
|
239
|
+
# Device memory pool (should) come to rescue performance
|
240
|
+
temp = self.zeros(data.shape, data.dtype)
|
241
|
+
texture = self._get_texture(data, order=order, prefilter=prefilter)
|
242
|
+
texture.affine(transform_m=matrix, profile=False, output=temp)
|
243
|
+
output[out_slice] = temp
|
244
|
+
return None
|
249
245
|
|
250
246
|
self.affine_transform(
|
251
247
|
input=data,
|
tme/backends/jax_backend.py
CHANGED
@@ -7,7 +7,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
from functools import wraps
|
10
|
-
from typing import Tuple, List, Callable
|
10
|
+
from typing import Tuple, List, Callable, Dict
|
11
11
|
|
12
12
|
from ..types import BackendArray
|
13
13
|
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
@@ -51,12 +51,6 @@ class JaxBackend(NumpyFFTWBackend):
|
|
51
51
|
)
|
52
52
|
self.scipy = jsp
|
53
53
|
self._create_ufuncs()
|
54
|
-
try:
|
55
|
-
from ._jax_utils import scan as _
|
56
|
-
|
57
|
-
self.scan = self._scan
|
58
|
-
except Exception:
|
59
|
-
pass
|
60
54
|
|
61
55
|
def from_sharedarr(self, arr: BackendArray) -> BackendArray:
|
62
56
|
return arr
|
@@ -189,7 +183,18 @@ class JaxBackend(NumpyFFTWBackend):
|
|
189
183
|
rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
|
190
184
|
return max_scores, rotations
|
191
185
|
|
192
|
-
def
|
186
|
+
def compute_convolution_shapes(
|
187
|
+
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
188
|
+
) -> Tuple[List[int], List[int], List[int]]:
|
189
|
+
from scipy.fft import next_fast_len
|
190
|
+
|
191
|
+
convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
|
192
|
+
fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
|
193
|
+
fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
|
194
|
+
|
195
|
+
return convolution_shape, fast_shape, fast_ft_shape
|
196
|
+
|
197
|
+
def scan(
|
193
198
|
self,
|
194
199
|
matching_data: type,
|
195
200
|
splits: Tuple[Tuple[slice, slice]],
|
@@ -214,9 +219,9 @@ class JaxBackend(NumpyFFTWBackend):
|
|
214
219
|
conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
|
215
220
|
target_shape=self.to_numpy_array(target_shape),
|
216
221
|
template_shape=self.to_numpy_array(matching_data._template.shape),
|
217
|
-
|
222
|
+
batch_mask=self.to_numpy_array(matching_data._batch_mask),
|
223
|
+
pad_target=pad_target,
|
218
224
|
)
|
219
|
-
|
220
225
|
analyzer_args = {
|
221
226
|
"convolution_mode": convolution_mode,
|
222
227
|
"fourier_shift": shift,
|
@@ -246,19 +251,18 @@ class JaxBackend(NumpyFFTWBackend):
|
|
246
251
|
|
247
252
|
targets, translation_offsets = [], []
|
248
253
|
for target_split, template_split in split_subset:
|
249
|
-
base = matching_data.subset_by_slice(
|
254
|
+
base, translation_offset = matching_data.subset_by_slice(
|
250
255
|
target_slice=target_split,
|
251
256
|
target_pad=target_pad,
|
252
257
|
template_slice=template_split,
|
253
258
|
)
|
254
|
-
translation_offsets.append(
|
259
|
+
translation_offsets.append(translation_offset)
|
255
260
|
targets.append(self.topleft_pad(base._target, fast_shape))
|
256
261
|
|
257
262
|
if create_filter:
|
258
263
|
filter_args = {
|
259
264
|
"data_rfft": self.fft.rfftn(targets[0]),
|
260
265
|
"return_real_fourier": True,
|
261
|
-
"shape_is_real_fourier": False,
|
262
266
|
}
|
263
267
|
|
264
268
|
if create_template_filter:
|
@@ -288,12 +292,11 @@ class JaxBackend(NumpyFFTWBackend):
|
|
288
292
|
|
289
293
|
for index in range(scores.shape[0]):
|
290
294
|
temp = callback_class(
|
291
|
-
shape=scores.shape,
|
295
|
+
shape=scores[index].shape,
|
292
296
|
offset=translation_offsets[index],
|
293
297
|
)
|
294
|
-
state = (scores, rotations, rotation_mapping)
|
298
|
+
state = (scores[index], rotations[index], rotation_mapping)
|
295
299
|
ret.append(temp.result(state, **analyzer_args))
|
296
|
-
|
297
300
|
return ret
|
298
301
|
|
299
302
|
def get_available_memory(self) -> int:
|
tme/backends/npfftw_backend.py
CHANGED
@@ -398,33 +398,33 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
398
398
|
out_mask: NDArray = None,
|
399
399
|
order: int = 3,
|
400
400
|
cache: bool = False,
|
401
|
+
batched: bool = False,
|
401
402
|
) -> Tuple[NDArray, NDArray]:
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
translation=
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
|
403
|
+
if out is None:
|
404
|
+
out = self.zeros_like(arr)
|
405
|
+
|
406
|
+
# Check whether rotation_matrix is already a rigid transform matrix
|
407
|
+
matrix = rotation_matrix
|
408
|
+
if matrix.shape[-1] == (arr.ndim - int(batched)):
|
409
|
+
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
|
410
|
+
if not use_geometric_center:
|
411
|
+
center = self.center_of_mass(arr, cutoff=0)
|
412
|
+
|
413
|
+
offset = int(arr.ndim - rotation_matrix.shape[0])
|
414
|
+
center = center[offset:]
|
415
|
+
translation = (
|
416
|
+
self.zeros(center.size) if translation is None else translation
|
417
|
+
)
|
418
|
+
matrix = self._rigid_transform_matrix(
|
419
|
+
rotation_matrix=rotation_matrix,
|
420
|
+
translation=translation,
|
421
|
+
center=center,
|
422
422
|
)
|
423
423
|
|
424
424
|
self._rigid_transform(
|
425
|
-
data=arr
|
425
|
+
data=arr,
|
426
426
|
matrix=matrix,
|
427
|
-
output=out
|
427
|
+
output=out,
|
428
428
|
order=order,
|
429
429
|
prefilter=True,
|
430
430
|
cache=cache,
|
@@ -433,11 +433,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
433
433
|
|
434
434
|
# Applying the prefilter leads to artifacts in the mask.
|
435
435
|
if arr_mask is not None:
|
436
|
-
|
436
|
+
if out_mask is None:
|
437
|
+
out_mask = self.zeros_like(arr_mask)
|
438
|
+
|
437
439
|
self._rigid_transform(
|
438
|
-
data=arr_mask
|
440
|
+
data=arr_mask,
|
439
441
|
matrix=matrix,
|
440
|
-
output=out_mask
|
442
|
+
output=out_mask,
|
441
443
|
order=order,
|
442
444
|
prefilter=False,
|
443
445
|
cache=cache,
|
tme/backends/pytorch_backend.py
CHANGED
@@ -306,6 +306,9 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
306
306
|
kwargs["dim"] = kwargs.pop("axes", None)
|
307
307
|
return self._array_backend.fft.irfftn(arr, **kwargs)
|
308
308
|
|
309
|
+
def _rigid_transform_matrix(self, rotation_matrix, *args, **kwargs):
|
310
|
+
return rotation_matrix
|
311
|
+
|
309
312
|
def rigid_transform(
|
310
313
|
self,
|
311
314
|
arr: TorchTensor,
|
@@ -317,6 +320,7 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
317
320
|
out_mask: TorchTensor = None,
|
318
321
|
order: int = 1,
|
319
322
|
cache: bool = False,
|
323
|
+
**kwargs,
|
320
324
|
):
|
321
325
|
_mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
|
322
326
|
mode = _mode_mapping.get(order, None)
|
tme/density.py
CHANGED
@@ -1763,12 +1763,13 @@ class Density:
|
|
1763
1763
|
axis=axis,
|
1764
1764
|
)
|
1765
1765
|
|
1766
|
-
|
1766
|
+
mask, mask_ret = np.where(mask), np.where(mask_ret)
|
1767
|
+
|
1768
|
+
arr_ft = np.fft.fftn(self.data)[mask]
|
1767
1769
|
arr_ft *= np.prod(ret_shape) / np.prod(self.shape)
|
1768
1770
|
ret_ft = np.zeros(ret_shape, dtype=arr_ft.dtype)
|
1769
|
-
ret_ft
|
1770
|
-
ret.data = np.real(np.fft.ifftn(ret_ft))
|
1771
|
-
|
1771
|
+
np.add.at(ret_ft, mask_ret, arr_ft)
|
1772
|
+
ret.data = np.real(np.fft.ifftn(ret_ft)).astype(self.data.dtype)
|
1772
1773
|
ret.sampling_rate = new_sampling_rate
|
1773
1774
|
return ret
|
1774
1775
|
|
tme/filters/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from .ctf import CTF, CTFReconstructed
|
2
2
|
from .compose import Compose, ComposableFilter
|
3
|
-
from .bandpass import
|
3
|
+
from .bandpass import BandPass, BandPassReconstructed
|
4
4
|
from .whitening import LinearWhiteningFilter
|
5
5
|
from .wedge import Wedge, WedgeReconstructed
|
6
|
-
from .reconstruction import ReconstructFromTilt
|
6
|
+
from .reconstruction import ReconstructFromTilt, ShiftFourier
|