pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +147 -93
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +67 -26
- scripts/preprocessor_gui.py +175 -85
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +451 -809
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +111 -223
- tme/backends/jax_backend.py +214 -150
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +239 -507
- tme/backends/pytorch_backend.py +21 -145
- tme/density.py +233 -363
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +322 -285
- tme/matching_exhaustive.py +172 -1493
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +280 -386
- tme/memory.py +377 -0
- tme/orientations.py +52 -12
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +34 -40
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/backends/__init__.py
CHANGED
@@ -4,14 +4,15 @@
|
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
|
-
|
8
7
|
from typing import Dict, List
|
8
|
+
from importlib.util import find_spec
|
9
9
|
|
10
10
|
from .matching_backend import MatchingBackend
|
11
11
|
from .npfftw_backend import NumpyFFTWBackend
|
12
12
|
from .pytorch_backend import PytorchBackend
|
13
13
|
from .cupy_backend import CupyBackend
|
14
14
|
from .mlx_backend import MLXBackend
|
15
|
+
from .jax_backend import JaxBackend
|
15
16
|
|
16
17
|
|
17
18
|
class BackendManager:
|
@@ -43,24 +44,26 @@ class BackendManager:
|
|
43
44
|
|
44
45
|
>>> backend.change_backend("pytorch")
|
45
46
|
>>> backend.multiply(arr1, arr2)
|
46
|
-
# This will use the
|
47
|
+
# This will use the pytorchs multiply method
|
48
|
+
|
49
|
+
>>> backend.available_backends()
|
50
|
+
# Backends available on your system
|
47
51
|
|
48
52
|
Notes
|
49
53
|
-----
|
50
|
-
|
51
|
-
between backends, use the `change_backend` method. Note that the backend
|
52
|
-
has to be reinitialzed when using fork-based parallelism.
|
54
|
+
The backend has to be reinitialzed when using fork-based parallelism.
|
53
55
|
"""
|
54
56
|
|
55
57
|
def __init__(self):
    # Maps backend name -> backend class. Insertion order is the order in
    # which available_backends reports names.
    self._BACKEND_REGISTRY = dict(
        numpyfftw=NumpyFFTWBackend,
        pytorch=PytorchBackend,
        cupy=CupyBackend,
        mlx=MLXBackend,
        jax=JaxBackend,
    )
    # Start on the NumPy/FFTW backend, instantiated without extra arguments.
    self._backend_name = "numpyfftw"
    self._backend = self._BACKEND_REGISTRY[self._backend_name]()
    self._backend_args = {}
|
65
68
|
|
66
69
|
def __repr__(self):
|
@@ -99,7 +102,8 @@ class BackendManager:
|
|
99
102
|
Raises
|
100
103
|
------
|
101
104
|
ValueError
|
102
|
-
If the provided backend_instance does not inherit from
|
105
|
+
If the provided backend_instance does not inherit from
|
106
|
+
:py:class:`MatchingBackend`.
|
103
107
|
"""
|
104
108
|
if not issubclass(backend_class, MatchingBackend):
|
105
109
|
raise ValueError("backend_class needs to inherit from MatchingBackend.")
|
@@ -122,9 +126,7 @@ class BackendManager:
|
|
122
126
|
If no backend is found with the provided name.
|
123
127
|
"""
|
124
128
|
if backend_name not in self._BACKEND_REGISTRY:
|
125
|
-
available_backends = ", ".join(
|
126
|
-
[str(x) for x in self._BACKEND_REGISTRY.keys()]
|
127
|
-
)
|
129
|
+
available_backends = ", ".join(self.available_backends())
|
128
130
|
raise NotImplementedError(
|
129
131
|
f"Available backends are {available_backends} - not {backend_name}."
|
130
132
|
)
|
@@ -132,5 +134,32 @@ class BackendManager:
|
|
132
134
|
self._backend_name = backend_name
|
133
135
|
self._backend_args = backend_kwargs
|
134
136
|
|
137
|
+
def available_backends(self) -> List[str]:
    """
    Determines importable backends.

    A backend counts as available when the top-level module it depends on can
    be located with :py:func:`importlib.util.find_spec`. The module is not
    actually imported, so this is an approximation that avoids import cost and
    side effects. Registered backends without a known dependency are skipped.

    Returns
    -------
    list of str
        Names of backends available for template matching, in registry order.
    """
    # Registry name -> importable top-level module name. PyTorch is imported
    # as ``torch``; ``pytorch`` is not its module name, so checking it would
    # never detect an installed PyTorch.
    _dependencies = {
        "numpyfftw": "numpy",
        "cupy": "cupy",
        "pytorch": "torch",
        "mlx": "mlx",
        "jax": "jax",
    }
    return [
        name
        for name in self._BACKEND_REGISTRY
        if name in _dependencies and find_spec(_dependencies[name]) is not None
    ]
|
163
|
+
|
135
164
|
|
136
165
|
# Package-wide BackendManager instance, starting on the "numpyfftw" backend.
# Other modules import it, e.g. ``from ..backends import backend as be``.
backend = BackendManager()
|
@@ -0,0 +1,185 @@
|
|
1
|
+
""" Utility functions for jax backend.
|
2
|
+
|
3
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
from typing import Tuple
|
8
|
+
from functools import partial
|
9
|
+
|
10
|
+
import jax.numpy as jnp
|
11
|
+
from jax import pmap, lax
|
12
|
+
|
13
|
+
from ..types import BackendArray
|
14
|
+
from ..backends import backend as be
|
15
|
+
from ..matching_utils import normalize_template as _normalize_template
|
16
|
+
|
17
|
+
|
18
|
+
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
    """
    Computes :py:meth:`tme.matching_exhaustive.cc_setup`.

    Multiplies the ``rfftn`` of ``template`` with ``ft_target`` and returns
    the real-space result.

    Parameters
    ----------
    template : BackendArray
        Real-space template. Assumed to have the same shape as the target
        that produced ``ft_target`` (callers pad it to that shape).
    ft_target : BackendArray
        Half-spectrum (``rfftn``) of the target.

    Returns
    -------
    BackendArray
        Real-valued result with the same shape as ``template``.
    """
    template_ft = jnp.fft.rfftn(template)
    template_ft = template_ft.at[:].multiply(ft_target)
    # Pass the real-space shape explicitly. Without ``s``, irfftn assumes an
    # even last axis and returns the wrong shape for odd-sized inputs.
    correlation = jnp.fft.irfftn(template_ft, s=template.shape)
    return correlation
|
26
|
+
|
27
|
+
|
28
|
+
def _flc_scoring(
    template: BackendArray,
    template_mask: BackendArray,
    ft_target: BackendArray,
    ft_target2: BackendArray,
    n_observations: BackendArray,
    eps: float,
    **kwargs,
) -> BackendArray:
    """
    Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.

    Correlates ``template`` with the target and scales the result by the
    reciprocal masked standard deviation of the target under
    ``template_mask``. Extra keyword arguments are accepted and ignored, so
    every scoring function in this module can be called with the same
    keyword set.
    """
    normalization = _reciprocal_target_std(
        ft_target=ft_target,
        ft_target2=ft_target2,
        template_mask=template_mask,
        eps=eps,
        n_observations=n_observations,
    )
    scores = _correlate(template=template, ft_target=ft_target)
    return scores.at[:].multiply(normalization)
|
50
|
+
|
51
|
+
|
52
|
+
def _flcSphere_scoring(
    template: BackendArray,
    ft_target: BackendArray,
    inv_denominator: BackendArray,
    **kwargs,
) -> BackendArray:
    """
    Computes the ``flcSphere`` variant of fast local correlation.

    Unlike :py:func:`_flc_scoring`, the target normalization is not
    recomputed. The precomputed ``inv_denominator`` is used instead, which
    ``scan`` builds once from the unrotated mask when ``rotate_mask`` is
    False. NOTE(review): the matching reference implementation presumably
    lives in ``tme.matching_scores`` / ``tme.matching_exhaustive`` as
    ``flcSphere_scoring``; confirm.

    Parameters
    ----------
    template : BackendArray
        Real-space template, padded to the target shape.
    ft_target : BackendArray
        Half-spectrum (``rfftn``) of the target.
    inv_denominator : BackendArray
        Precomputed reciprocal target normalization.
    **kwargs
        Ignored. Accepted so all scoring functions share one call signature.

    Returns
    -------
    BackendArray
        Normalized scores with the shape of ``template``.
    """
    correlation = _correlate(template=template, ft_target=ft_target)
    correlation = correlation.at[:].multiply(inv_denominator)
    return correlation
|
64
|
+
|
65
|
+
|
66
|
+
def _reciprocal_target_std(
    ft_target: BackendArray,
    ft_target2: BackendArray,
    template_mask: BackendArray,
    n_observations: float,
    eps: float,
) -> BackendArray:
    """
    Computes reciprocal standard deviation of a target given a mask.

    For every shift, the variance is ``E(X^2) - E(X)^2``, with expectations
    taken over the voxels covered by ``template_mask`` and normalized by
    ``n_observations``. The result is ``1 / (std * n_observations)``, or 0
    wherever ``std <= eps``.

    Parameters
    ----------
    ft_target : BackendArray
        Half-spectrum (``rfftn``) of the target.
    ft_target2 : BackendArray
        Half-spectrum (``rfftn``) of the squared target.
    template_mask : BackendArray
        Real-space template mask, same shape as the target.
    n_observations : float
        Normalization for the masked expectations (sum of the mask in callers).
    eps : float
        Standard deviations at or below this threshold map to 0.

    Returns
    -------
    BackendArray
        Array with the shape of ``template_mask``.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_scoring`.
    """
    shape = template_mask.shape
    ft_template_mask = jnp.fft.rfftn(template_mask)

    # E(X^2) over the mask window. ``s`` preserves odd-sized last axes,
    # which irfftn would otherwise truncate to an even length.
    exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=shape)
    exp_sq = exp_sq / n_observations

    # E(X)^2 over the mask window.
    sq_exp = jnp.fft.irfftn(ft_target * ft_template_mask, s=shape)
    sq_exp = jnp.square(sq_exp / n_observations)

    # Clamp small negative variances caused by floating point round-off.
    std = jnp.sqrt(jnp.maximum(exp_sq - sq_exp, 0))
    return jnp.where(std <= eps, 0, jnp.reciprocal(std * n_observations))
|
99
|
+
|
100
|
+
|
101
|
+
def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
    """
    Filters ``arr`` by multiplying its half-spectrum with ``arr_filter``.

    Parameters
    ----------
    arr : BackendArray
        Real-space array to filter.
    arr_filter : BackendArray
        Filter that must broadcast against ``rfftn(arr)``.

    Returns
    -------
    BackendArray
        Filtered array with the shape and dtype of ``arr``.
    """
    spectrum = jnp.fft.rfftn(arr).at[:].multiply(arr_filter)
    filtered = jnp.fft.irfftn(spectrum, s=arr.shape)
    return arr.at[:].set(filtered)
|
105
|
+
|
106
|
+
|
107
|
+
def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
    """
    No-op stand-in for :py:func:`_apply_fourier_filter`.

    Accepts the same arguments so it can be swapped in when no filter is
    applied, and returns ``arr`` unchanged.
    """
    del arr_filter  # only present to mirror the filter signature
    return arr
|
109
|
+
|
110
|
+
@partial(
    pmap,
    in_axes=(0,) + (None,) * 6,
    static_broadcasted_argnums=[6, 7],
)
def scan(
    target: BackendArray,
    template: BackendArray,
    template_mask: BackendArray,
    rotations: BackendArray,
    template_filter: BackendArray,
    target_filter: BackendArray,
    fast_shape: Tuple[int],
    rotate_mask: bool,
) -> Tuple[BackendArray, BackendArray]:
    """
    Scores ``template`` against ``target`` for every rotation in ``rotations``.

    Mapped with :py:func:`jax.pmap` over the leading axis of ``target``. The
    remaining array arguments are broadcast, and ``fast_shape`` /
    ``rotate_mask`` are static.

    Parameters
    ----------
    target : BackendArray
        Target to search; ``pmap`` maps over its leading axis.
    template : BackendArray
        Template that is rotated for each entry of ``rotations``.
    template_mask : BackendArray
        Mask of ``template``, rotated alongside it.
    rotations : BackendArray
        Rotation matrices, iterated along the leading axis by
        :py:func:`jax.lax.scan`.
    template_filter : BackendArray
        Fourier filter applied to each rotated template. A 0-d array
        (``shape == ()``) or ``None`` disables filtering.
    target_filter : BackendArray
        Fourier filter applied once to the target. A 0-d array or ``None``
        disables filtering.
    fast_shape : tuple of int
        Shape templates and masks are padded to before scoring.
    rotate_mask : bool
        If False, the target normalization is computed once from the
        unrotated mask and reused for every rotation (``_flcSphere_scoring``).
        Otherwise it is recomputed per rotation from the rotated mask
        (``_flc_scoring``).

    Returns
    -------
    tuple of BackendArray
        Score space and rotation space as accumulated by
        ``be.max_score_over_rotations``. Presumably these are per-voxel
        maxima and the index of the best rotation; rotation space starts
        at -1.
    """
    eps = jnp.finfo(template.dtype).resolution

    # 0-d arrays are "no filter" placeholders, the same convention used for
    # template_filter below. Under pmap every array argument has ``.shape``,
    # so a plain hasattr check would always filter and waste two FFTs.
    if getattr(target_filter, "shape", ()) != ():
        target = _apply_fourier_filter(target, target_filter)

    ft_target = jnp.fft.rfftn(target)
    ft_target2 = jnp.fft.rfftn(jnp.square(target))
    # Only the spectra are needed from here on; drop the real-space target.
    inv_denominator, target, scoring_func = None, None, _flc_scoring
    if not rotate_mask:
        # Compute the target normalization once from the unrotated mask and
        # share it across all rotations.
        n_observations = jnp.sum(template_mask)
        inv_denominator = _reciprocal_target_std(
            ft_target=ft_target,
            ft_target2=ft_target2,
            template_mask=be.topleft_pad(template_mask, fast_shape),
            eps=eps,
            n_observations=n_observations,
        )
        ft_target2, scoring_func = None, _flcSphere_scoring

    _template_filter_func = _identity
    if getattr(template_filter, "shape", ()) != ():
        _template_filter_func = _apply_fourier_filter

    def _sample_transform(ret, rotation_matrix):
        # lax.scan body. The carry is (score space, rotation space, index of
        # the current rotation).
        max_scores, best_rotations, index = ret
        template_rot, template_mask_rot = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation_matrix,
            order=1,  # thats all we get for now
        )

        n_observations = jnp.sum(template_mask_rot)
        template_rot = _template_filter_func(template_rot, template_filter)
        template_rot = _normalize_template(
            template_rot, template_mask_rot, n_observations
        )
        template_rot = be.topleft_pad(template_rot, fast_shape)
        template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)

        scores = scoring_func(
            template=template_rot,
            template_mask=template_mask_rot,
            ft_target=ft_target,
            ft_target2=ft_target2,
            inv_denominator=inv_denominator,
            n_observations=n_observations,
            eps=eps,
        )
        max_scores, best_rotations = be.max_score_over_rotations(
            scores, max_scores, best_rotations, index
        )
        return (max_scores, best_rotations, index + 1), None

    score_space = jnp.zeros(fast_shape)
    rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
    (score_space, rotation_space, _), _ = lax.scan(
        _sample_transform, (score_space, rotation_space, 0), rotations
    )

    return score_space, rotation_space
|