pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +50 -103
- {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
- pytme-0.3.2.dev0.dist-info/RECORD +136 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +50 -103
- scripts/pytme_runner.py +46 -69
- scripts/refine_matches.py +5 -7
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +124 -71
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +110 -105
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +102 -58
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +28 -8
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- pytme-0.3.1.post1.dist-info/RECORD +0 -133
- {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/analyzer/aggregation.py
CHANGED
@@ -192,6 +192,15 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         )
         return scores, rotations, rotation_mapping, ssum
 
+    def correct_background(self, state, mean=0, inv_std=1, **kwargs):
+        scores, rotations, rotation_mapping, ssum = state
+
+        scores = be.subtract(scores, mean, out=scores)
+        scores = be.multiply(scores, inv_std, out=scores)
+
+        scores = be.maximum(scores, self._score_threshold, out=scores)
+        return scores, rotations, rotation_mapping, ssum
+
     @staticmethod
     def _invert_rmap(rotation_mapping: dict) -> dict:
         """
@@ -201,7 +210,12 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         new_map, ndim = {}, None
         for k, v in rotation_mapping.items():
             nbytes = be.datatype_bytes(be._float_dtype)
-
+            if nbytes == 8:
+                dtype = np.float64
+            elif nbytes == 4:
+                dtype = np.float32
+            else:
+                dtype = np.float16
             rmat = np.frombuffer(k, dtype=dtype)
             if ndim is None:
                 ndim = int(np.sqrt(rmat.size))
@@ -451,7 +465,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
        Maximum accepted rotational deviation in degrees.
    positions : BackendArray
        Array of shape (n, d) with n seed point translations.
-
+    rotations : BackendArray
        Array of shape (n, d, d) with n seed point rotation matrices.
    reference : BackendArray
        Reference orientation of the template, wlog defaults to (0,0,1).
@@ -489,6 +503,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
             be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
         )
         positions = be.astype(be.to_backend_array(positions), be._int_dtype)
+        rotations = be.astype(be.to_backend_array(rotations), be._float_dtype)
 
         ndim = positions.shape[1]
         rotate_mask = len(set(acceptance_radius)) != 1
@@ -515,7 +530,13 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
         )
 
         self._positions = positions[valid_positions]
-        rotations =
+        rotations = rotations[valid_positions]
+
+        # Convert to pull matrix to remain consistent with rotation convention
+        rotations = be.concatenate(
+            [rotations[i].T[None] for i in range(rotations.shape[0])]
+        )
+
         ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
         ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
         ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
@@ -524,6 +545,15 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
         self._normals_y = (rotations @ ey[..., None])[..., 0]
         self._normals_z = (rotations @ ez[..., None])[..., 0]
 
+        # All scores will be rejected in this case. We should think about a
+        # unified interface for checking analyzer validity to skip such runs
+        if self._positions.shape[0] == 0:
+
+            def _get_score_mask(*args, **kwargs):
+                return 0
+
+            self._get_score_mask = _get_score_mask
+
         # Periodic wrapping could be avoided by padding the target
         shape = be.to_backend_array(self._shape)
         starts = be.subtract(self._positions, extend)
@@ -539,9 +569,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
         self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
 
         if rotate_mask:
-            self._score_mask =
-                (rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
-            )
+            self._score_mask = []
             for i in range(rotations.shape[0]):
                 mask = create_mask(
                     mask_type="ellipse",
@@ -550,9 +578,10 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
                     center=tuple(extend for _ in range(ndim)),
                     orientation=be.to_numpy_array(rotations[i]),
                 )
-                self._score_mask
-                    be.to_backend_array(mask), be._float_dtype
+                self._score_mask.append(
+                    be.astype(be.to_backend_array(mask), be._float_dtype)[None]
                 )
+            self._score_mask = be.concatenate(self._score_mask)
 
     def __call__(
         self,
@@ -573,7 +602,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
         """
         Determine whether the angle between projection of reference w.r.t to
         a given rotation matrix and a set of rotations fall within the set
-
+        cone angle cutoff.
 
         Parameters
         ----------
@@ -585,7 +614,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
        BackerndArray
            Boolean mask of shape (n, )
        """
-        template_rot = rotation_matrix @ self._reference
+        template_rot = rotation_matrix.T @ self._reference
 
         x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
         y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
@@ -596,10 +625,9 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
     def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
         score_mask = be.zeros(scores.shape, scores.dtype)
 
-
-
+        # The indexing could be improved to avoid expanding the mask to
+        # the number of seed points
         mask = be.reshape(mask, self._mask_shape)
-
         score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
         return score_mask > 0
 
@@ -663,13 +691,16 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
         rotation_index = len(rotation_mapping)
         if self._inversion_mapping:
             rotation_mapping[rotation_index] = rotation_matrix
+        elif self._jax_mode:
+            rotation_index = kwargs.get("rotation_index", 0)
         else:
             rotation = be.tobytes(rotation_matrix)
             rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
-        max_score = be.max(scores, axis=self._aggregate_axis)
 
-
-
+        scores = be.max(scores, axis=self._aggregate_axis)
+        scores = be.maximum(scores, prev_scores[rotation_index])
+        prev_scores = be.at(prev_scores, rotation_index, scores)
+
         return prev_scores, rotations, rotation_mapping
 
     @classmethod
tme/analyzer/base.py
CHANGED
@@ -73,6 +73,40 @@ class AbstractAnalyzer(ABC):
            Updated analyzer state incorporating the new data.
        """
 
+    @abstractmethod
+    def correct_background(self, state, mean=0, inv_std=1, **kwargs):
+        """
+        Applies flat-fielding correction to scores f as
+
+        .. math::
+
+            f' = (f - \\text{mean}) \\cdot \\text{inv_std},
+
+        transforming raw correlations to SNR-like scores.
+
+        Parameters
+        ----------
+        state : tuple
+            Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
+            or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
+        mean : BackendArray, optional
+            Background mean (or equivalent), defaults to 0.
+        inv_std : BackendArray, optional
+            Reciprocal background standard deviation (or equivalent), defaults to 1.
+
+        Notes
+        -----
+        This method should be called after all rotations have been processed
+        but before calling :py:meth:`result`. The correction helps distinguish genuine
+        template matches from systematic background artifacts that may arise from
+        template edges, interpolation artifacts, or structured noise in the target.
+
+        Returns
+        -------
+        tuple
+            Updated analyzer state incorporating the new data.
+        """
+
     @abstractmethod
     def result(self, state: Tuple, **kwargs) -> Tuple:
         """
tme/analyzer/peaks.py
CHANGED
@@ -18,7 +18,6 @@ from .base import AbstractAnalyzer
 from ._utils import score_to_cart
 from ..backends import backend as be
 from ..types import BackendArray, NDArray
-from ..rotations import euler_to_rotationmatrix
 from ..matching_utils import split_shape, compute_extraction_box
 
 __all__ = [
@@ -182,6 +181,7 @@ class PeakCaller(AbstractAnalyzer):
         min_score: float = None,
         max_score: float = None,
         batch_dims: Tuple[int] = None,
+        projection_dims: Tuple[int] = None,
         shm_handler: object = None,
         **kwargs,
     ):
@@ -197,9 +197,13 @@ class PeakCaller(AbstractAnalyzer):
         self.min_distance = int(min_distance)
         self.min_boundary_distance = int(min_boundary_distance)
 
-        self.batch_dims =
+        self.batch_dims = ()
         if batch_dims is not None:
-            self.batch_dims = tuple(int(x) for x in
+            self.batch_dims = tuple(int(x) for x in batch_dims)
+
+        self.projection_dims = ()
+        if projection_dims is not None:
+            self.projection_dims = tuple(int(x) for x in projection_dims)
 
         self.min_score, self.max_score = min_score, max_score
 
@@ -231,7 +235,7 @@ class PeakCaller(AbstractAnalyzer):
 
         rdim = len(self.shape)
         if self.batch_dims:
-            rdim
+            rdim = rdim - len(self.batch_dims) + len(self.projection_dims)
 
         rotations = be.full(
             (self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
@@ -388,6 +392,20 @@ class PeakCaller(AbstractAnalyzer):
 
         return state
 
+    def correct_background(self, state, mean, inv_std=1, **kwargs):
+        arr_type = type(be.zeros((1,), be._float_dtype))
+        translations, rotations, scores, details = state
+
+        if isinstance(mean, arr_type):
+            mean = mean[tuple(be.astype(translations.T, int))]
+        scores = be.subtract(scores, mean, out=scores)
+
+        if isinstance(inv_std, arr_type):
+            inv_std = inv_std[tuple(be.astype(translations.T, int))]
+        scores = be.multiply(scores, inv_std, out=scores)
+
+        return translations, rotations, scores, details
+
     @classmethod
     def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
         """
@@ -778,6 +796,9 @@ class PeakCallerRecursiveMasking(PeakCaller):
         mask = be.to_backend_array(mask)
         mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
 
+        if min_score is None:
+            min_score = self.min_score
+
         if min_score is None:
             min_score = be.min(scores) - 1
 
@@ -849,15 +870,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
         """
         if rotation_space is None or rotation_mapping is None:
             return rotation_matrix
-
-        rotation = rotation_mapping[rotation_space[tuple(peak)]]
-
-        # Old versions of rotation mapping contained Euler angles
-        if rotation.ndim != 2:
-            rotation = be.to_backend_array(
-                euler_to_rotationmatrix(be.to_numpy_array(rotation))
-            )
-        return rotation
+        return rotation_mapping[rotation_space[tuple(peak)]]
 
 
 class PeakCallerScipy(PeakCaller):
tme/analyzer/proxy.py
CHANGED
@@ -85,6 +85,16 @@ class StatelessSharedAnalyzerProxy:
         final_state = tuple(self._shared_to_object(x) for x in final_state)
         return self._analyzer.result(final_state, **kwargs)
 
+    def correct_background(self, state, *args, **kwargs):
+        if self._shared:
+            # Copy to not correct the internal score array across processes
+            backend_arr = type(be.zeros((1), dtype=be._float_dtype))
+            state = tuple(self._shared_to_object(x) for x in state)
+            state = tuple(
+                be.copy(x) if isinstance(x, backend_arr) else x for x in state
+            )
+        return self._analyzer.correct_background(state, *args, **kwargs)
+
     def merge(self, *args, **kwargs):
         return self._analyzer.merge(*args, **kwargs)
 
@@ -121,3 +131,7 @@ class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
     def result(self, **kwargs):
         """Extract final result"""
         return super().result(self._state, **kwargs)
+
+    def correct_background(self, *args, **kwargs):
+        # We always assign to state as this operation can not be shared
+        self._state = super().correct_background(self._state, *args, **kwargs)
tme/backends/_jax_utils.py
CHANGED
@@ -10,14 +10,14 @@ from typing import Tuple
 from functools import partial
 
 import jax.numpy as jnp
-from jax import pmap, lax,
+from jax import pmap, lax, jit
 
 from ..types import BackendArray
 from ..backends import backend as be
-from ..matching_utils import
+from ..matching_utils import standardize, to_padded
 
 
-__all__ = ["scan"]
+__all__ = ["scan", "setup_scan"]
 
 
 def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
@@ -62,15 +62,14 @@ def _flcSphere_scoring(
     Computes :py:meth:`tme.matching_scores.corr_scoring`.
     """
     correlation = _correlate(template=template, ft_target=ft_target)
-
-    return correlation
+    return correlation.at[:].multiply(inv_denominator)
 
 
 def _reciprocal_target_std(
     ft_target: BackendArray,
     ft_target2: BackendArray,
     template_mask: BackendArray,
-
+    n_obs: float,
     eps: float,
 ) -> BackendArray:
     """
@@ -80,16 +79,16 @@ def _reciprocal_target_std(
     --------
     :py:meth:`tme.matching_scores.flc_scoring`.
     """
-
-    ft_template_mask = jnp.fft.rfftn(template_mask, s=
+    shape = template_mask.shape
+    ft_template_mask = jnp.fft.rfftn(template_mask, s=shape)
 
     # E(X^2)- E(X)^2
-    exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=
-    exp_sq = exp_sq.at[:].divide(
+    exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=shape)
+    exp_sq = exp_sq.at[:].divide(n_obs)
 
     ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
-    sq_exp = jnp.fft.irfftn(ft_template_mask, s=
-    sq_exp = sq_exp.at[:].divide(
+    sq_exp = jnp.fft.irfftn(ft_template_mask, s=shape)
+    sq_exp = sq_exp.at[:].divide(n_obs)
     sq_exp = sq_exp.at[:].power(2)
 
     exp_sq = exp_sq.at[:].add(-sq_exp)
@@ -97,7 +96,7 @@ def _reciprocal_target_std(
     exp_sq = exp_sq.at[:].power(0.5)
 
     exp_sq = exp_sq.at[:].set(
-        jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq *
+        jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_obs))
     )
     return exp_sq
 
@@ -108,20 +107,50 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
     return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
 
 
-def
-
+def setup_scan(analyzer_kwargs, analyzer, fast_shape, rotate_mask, match_projection):
+    """Create separate scan function with initialized analyzer for each device"""
+    device_scans = [
+        partial(
+            scan,
+            fast_shape=fast_shape,
+            rotate_mask=rotate_mask,
+            analyzer=analyzer(**device_config),
+        )
+        for device_config in analyzer_kwargs
+    ]
 
+    @partial(
+        pmap,
+        in_axes=(0,) + (None,) * 7,
+        axis_name="batch",
+    )
+    def scan_combined(
+        target,
+        template,
+        template_mask,
+        rotations,
+        template_filter,
+        target_filter,
+        score_mask,
+        background_template,
+    ):
+        return lax.switch(
+            lax.axis_index("batch"),
+            device_scans,
+            target,
+            template,
+            template_mask,
+            rotations,
+            template_filter,
+            target_filter,
+            score_mask,
+            background_template,
+        )
 
-
-    return arr.at[:].multiply(mask)
+    return scan_combined
 
 
-@partial(
-    pmap,
-    in_axes=(0,) + (None,) * 7,
-    static_broadcasted_argnums=[7, 8, 9, 10],
-    axis_name="batch",
-)
+@partial(jit, static_argnums=(8, 9, 10))
 def scan(
     target: BackendArray,
     template: BackendArray,
@@ -130,74 +159,98 @@ def scan(
     template_filter: BackendArray,
     target_filter: BackendArray,
     score_mask: BackendArray,
+    background_template: BackendArray,
     fast_shape: Tuple[int],
     rotate_mask: bool,
-
-
-) -> Tuple[BackendArray, BackendArray]:
+    analyzer: object,
+) -> Tuple:
     eps = jnp.finfo(template.dtype).resolution
 
-
-        lax.axis_index("batch"),
-        [lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
-    )
-    analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
-
-    if hasattr(target_filter, "shape"):
+    if target_filter.shape != ():
        target = _apply_fourier_filter(target, target_filter)
 
     ft_target = jnp.fft.rfftn(target, s=fast_shape)
     ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
-
+    _n_obs, _inv_denominator, target = None, None, None
+
+    unpadded_slice = tuple(slice(0, x) for x in template.shape)
+    rot_buffer, mask_rot_buffer = jnp.zeros(fast_shape), jnp.zeros(fast_shape)
     if not rotate_mask:
-
-
+        _n_obs = jnp.sum(template_mask)
+        _inv_denominator = _reciprocal_target_std(
            ft_target=ft_target,
            ft_target2=ft_target2,
-            template_mask=
+            template_mask=to_padded(mask_rot_buffer, template_mask, unpadded_slice),
            eps=eps,
-
+            n_obs=_n_obs,
        )
-        ft_target2
+        ft_target2 = None
 
-
-
-
+    mask_scores = score_mask.shape != ()
+    filter_template = template_filter.shape != ()
+    bg_correction = background_template.shape != ()
+    bg_scores = jnp.zeros(fast_shape) if bg_correction else 0
 
-
-
-
+    _template_mask_rot = template_mask
+    template_indices = be._index_grid(template.shape)
+    center = be.divide(be.to_backend_array(template.shape) - 1, 2)
 
     def _sample_transform(ret, rotation_matrix):
-
-
-            arr=template,
-            arr_mask=template_mask,
-            rotation_matrix=rotation_matrix,
-            order=1, # thats all we get for now
+        matrix = be._build_transform_matrix(
+            rotation_matrix=rotation_matrix, center=center
         )
+        indices = be._transform_indices(template_indices, matrix)
+
+        template_rot = be._interpolate(template, indices, order=1)
+        n_obs, template_mask_rot = _n_obs, _template_mask_rot
+        if rotate_mask:
+            template_mask_rot = be._interpolate(template_mask, indices, order=1)
+            n_obs = jnp.sum(template_mask_rot)
+
+        if filter_template:
+            template_rot = _apply_fourier_filter(template_rot, template_filter)
+        template_rot = standardize(template_rot, template_mask_rot, n_obs)
+
+        rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
+
+        inv_denominator = _inv_denominator
+        if rotate_mask:
+            mask_rot_pad = to_padded(mask_rot_buffer, template_mask_rot, unpadded_slice)
+            inv_denominator = _reciprocal_target_std(
+                ft_target=ft_target,
+                ft_target2=ft_target2,
+                template_mask=mask_rot_pad,
+                n_obs=n_obs,
+                eps=eps,
+            )
+
+        scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
+        if mask_scores:
+            scores = scores.at[:].multiply(score_mask)
+
+        state, bg_scores, index = ret
+        state = analyzer(state, scores, rotation_matrix, rotation_index=index)
 
-
-
-
-
-
-        rot_pad = be.topleft_pad(template_rot, fast_shape)
-        mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
+        if bg_correction:
+            template_rot = be._interpolate(background_template, indices, order=1)
+            if filter_template:
+                template_rot = _apply_fourier_filter(template_rot, template_filter)
+            template_rot = standardize(template_rot, template_mask_rot, n_obs)
 
-
-
-
-            ft_target=ft_target,
-            ft_target2=ft_target2,
-            inv_denominator=inv_denominator,
-            n_observations=n_observations,
-            eps=eps,
-        )
-        scores = _score_mask_func(scores, score_mask)
+            rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
+            scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
+            bg_scores = jnp.maximum(bg_scores, scores)
 
-
-
+        return (state, bg_scores, index + 1), None
+
+    (state, bg_scores, _), _ = lax.scan(
+        _sample_transform, (analyzer.init_state(), bg_scores, 0), rotations
+    )
+
+    if bg_correction:
+        if mask_scores:
+            bg_scores = bg_scores.at[:].multiply(score_mask)
+        bg_scores = bg_scores.at[:].add(-be.mean(bg_scores))
+        state = analyzer.correct_background(state, bg_scores)
 
-    (state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
     return state
tme/backends/cupy_backend.py
CHANGED
@@ -33,7 +33,6 @@ class CupyBackend(NumpyFFTWBackend):
         import cupy as cp
         import cupyx.scipy.fft as cufft
         from cupyx.scipy.ndimage import affine_transform, maximum_filter
-        from ._cupy_utils import affine_transform_batch
 
         float_dtype = cp.float32 if float_dtype is None else float_dtype
         complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
@@ -51,7 +50,6 @@ class CupyBackend(NumpyFFTWBackend):
         self._cufft = cufft
         self.maximum_filter = maximum_filter
         self.affine_transform = affine_transform
-        self.affine_transform_batch = affine_transform_batch
 
         itype = f"int{self.datatype_bytes(int_dtype) * 8}"
         ftype = f"float{self.datatype_bytes(float_dtype) * 8}"
@@ -157,8 +155,8 @@ class CupyBackend(NumpyFFTWBackend):
 
         from voltools import StaticVolume
 
-        # Only keep template and
-        if len(TEXTURE_CACHE) >=
+        # Only keep template, mask and noise template in cache
+        if len(TEXTURE_CACHE) >= 3:
             TEXTURE_CACHE.clear()
 
         interpolation = "filt_bspline"
@@ -174,7 +172,7 @@ class CupyBackend(NumpyFFTWBackend):
 
         return TEXTURE_CACHE[key]
 
-    def
+    def _transform(
         self,
         data: CupyArray,
         matrix: CupyArray,
@@ -182,21 +180,10 @@ class CupyBackend(NumpyFFTWBackend):
         prefilter: bool,
         order: int,
         cache: bool = False,
-
-    ) -> None:
+    ) -> CupyArray:
         out_slice = tuple(slice(0, stop) for stop in data.shape)
-        if batched:
-            self.affine_transform_batch(
-                input=data,
-                matrix=matrix,
-                mode="constant",
-                output=output[out_slice],
-                order=order,
-                prefilter=prefilter,
-            )
-            return None
 
-        if data.ndim == 3 and cache and self.texture_available
+        if data.ndim == 3 and cache and self.texture_available:
             # Device memory pool (should) come to rescue performance
             temp = self.zeros(data.shape, data.dtype)
             texture = self._get_texture(data, order=order, prefilter=prefilter)
@@ -204,7 +191,7 @@ class CupyBackend(NumpyFFTWBackend):
             output[out_slice] = temp
             return None
 
-        self.affine_transform(
+        return self.affine_transform(
             input=data,
             matrix=matrix,
             mode="constant",