pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +147 -93
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +67 -26
- scripts/preprocessor_gui.py +175 -85
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +451 -809
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +111 -223
- tme/backends/jax_backend.py +214 -150
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +239 -507
- tme/backends/pytorch_backend.py +21 -145
- tme/density.py +233 -363
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +322 -285
- tme/matching_exhaustive.py +172 -1493
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +280 -386
- tme/memory.py +377 -0
- tme/orientations.py +52 -12
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +34 -40
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/backends/jax_backend.py
CHANGED
@@ -1,17 +1,38 @@
|
|
1
1
|
""" Backend using jax for template matching.
|
2
2
|
|
3
|
-
Copyright (c) 2023 European Molecular Biology Laboratory
|
3
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
|
-
from
|
7
|
+
from functools import wraps
|
8
|
+
from typing import Tuple, List, Callable
|
9
|
+
|
10
|
+
from ..types import BackendArray
|
11
|
+
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
12
|
+
|
13
|
+
|
14
|
+
def emulate_out(func):
|
15
|
+
"""
|
16
|
+
Adds an out argument to write output of ``func`` to.
|
17
|
+
"""
|
18
|
+
|
19
|
+
@wraps(func)
|
20
|
+
def inner(*args, out=None, **kwargs):
|
21
|
+
ret = func(*args, **kwargs)
|
22
|
+
if out is not None:
|
23
|
+
out = out.at[:].set(ret)
|
24
|
+
return out
|
25
|
+
return ret
|
26
|
+
|
27
|
+
return inner
|
8
28
|
|
9
|
-
from .npfftw_backend import NumpyFFTWBackend
|
10
29
|
|
11
30
|
class JaxBackend(NumpyFFTWBackend):
|
12
|
-
|
13
|
-
|
14
|
-
|
31
|
+
"""
|
32
|
+
A jax-based matching backend.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
|
15
36
|
import jax.scipy as jsp
|
16
37
|
import jax.numpy as jnp
|
17
38
|
|
@@ -19,24 +40,33 @@ class JaxBackend(NumpyFFTWBackend):
|
|
19
40
|
complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
|
20
41
|
int_dtype = jnp.int32 if int_dtype is None else int_dtype
|
21
42
|
|
22
|
-
self.scipy = jsp
|
23
43
|
super().__init__(
|
24
44
|
array_backend=jnp,
|
25
45
|
float_dtype=float_dtype,
|
26
46
|
complex_dtype=complex_dtype,
|
27
47
|
int_dtype=int_dtype,
|
48
|
+
overflow_safe_dtype=float_dtype,
|
28
49
|
)
|
50
|
+
self.scipy = jsp
|
51
|
+
self._create_ufuncs()
|
52
|
+
try:
|
53
|
+
from ._jax_utils import scan as _
|
29
54
|
|
30
|
-
|
31
|
-
|
55
|
+
self.scan = self._scan
|
56
|
+
except Exception:
|
57
|
+
pass
|
32
58
|
|
33
|
-
def
|
34
|
-
arr
|
59
|
+
def from_sharedarr(self, arr: BackendArray) -> BackendArray:
|
60
|
+
return arr
|
61
|
+
|
62
|
+
@staticmethod
|
63
|
+
def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
|
35
64
|
return arr
|
36
65
|
|
37
|
-
def topleft_pad(
|
38
|
-
|
39
|
-
|
66
|
+
def topleft_pad(
|
67
|
+
self, arr: BackendArray, shape: Tuple[int], padval: int = 0
|
68
|
+
) -> BackendArray:
|
69
|
+
b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
|
40
70
|
aind = [slice(None, None)] * arr.ndim
|
41
71
|
bind = [slice(None, None)] * arr.ndim
|
42
72
|
for i in range(arr.ndim):
|
@@ -47,43 +77,29 @@ class JaxBackend(NumpyFFTWBackend):
|
|
47
77
|
b = b.at[tuple(bind)].set(arr[tuple(aind)])
|
48
78
|
return b
|
49
79
|
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
out = out.at[:].set(ret)
|
74
|
-
return ret
|
75
|
-
|
76
|
-
def divide(self, x1, x2, out = None, *args, **kwargs):
|
77
|
-
x1 = self.to_backend_array(x1)
|
78
|
-
x2 = self.to_backend_array(x2)
|
79
|
-
ret = self._array_backend.divide(x1, x2, *args, **kwargs)
|
80
|
-
if out is not None:
|
81
|
-
out = out.at[:].set(ret)
|
82
|
-
return ret
|
83
|
-
|
84
|
-
def fill(self, arr, value: float) -> None:
|
85
|
-
arr.at[:].set(value)
|
86
|
-
|
80
|
+
def _create_ufuncs(self):
|
81
|
+
ufuncs = [
|
82
|
+
"add",
|
83
|
+
"subtract",
|
84
|
+
"multiply",
|
85
|
+
"divide",
|
86
|
+
"square",
|
87
|
+
"sqrt",
|
88
|
+
"maximum",
|
89
|
+
]
|
90
|
+
for ufunc in ufuncs:
|
91
|
+
backend_method = emulate_out(getattr(self._array_backend, ufunc))
|
92
|
+
setattr(self, ufunc, staticmethod(backend_method))
|
93
|
+
|
94
|
+
ufuncs = ["zeros", "full"]
|
95
|
+
for ufunc in ufuncs:
|
96
|
+
backend_method = getattr(self._array_backend, ufunc)
|
97
|
+
setattr(self, ufunc, staticmethod(backend_method))
|
98
|
+
|
99
|
+
def fill(self, arr: BackendArray, value: float) -> BackendArray:
|
100
|
+
return self._array_backend.full(
|
101
|
+
shape=arr.shape, dtype=arr.dtype, fill_value=value
|
102
|
+
)
|
87
103
|
|
88
104
|
def build_fft(
|
89
105
|
self,
|
@@ -92,127 +108,175 @@ class JaxBackend(NumpyFFTWBackend):
|
|
92
108
|
inverse_fast_shape: Tuple[int] = None,
|
93
109
|
**kwargs,
|
94
110
|
) -> Tuple[Callable, Callable]:
|
95
|
-
"""
|
96
|
-
Build fft builder functions.
|
97
|
-
|
98
|
-
Parameters
|
99
|
-
----------
|
100
|
-
fast_shape : tuple
|
101
|
-
Tuple of integers corresponding to fast convolution shape
|
102
|
-
(see :py:meth:`PytorchBackend.compute_convolution_shapes`).
|
103
|
-
fast_ft_shape : tuple
|
104
|
-
Tuple of integers corresponding to the shape of the Fourier
|
105
|
-
transform array (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
|
106
|
-
inverse_fast_shape : tuple, optional
|
107
|
-
Output shape of the inverse Fourier transform. By default fast_shape.
|
108
|
-
**kwargs : dict, optional
|
109
|
-
Unused keyword arguments.
|
110
|
-
|
111
|
-
Returns
|
112
|
-
-------
|
113
|
-
tuple
|
114
|
-
Tupple containing callable rfft and irfft object.
|
115
|
-
"""
|
116
111
|
if inverse_fast_shape is None:
|
117
112
|
inverse_fast_shape = fast_shape
|
118
113
|
|
119
|
-
def rfftn(
|
120
|
-
arr,
|
121
|
-
) -> None:
|
122
|
-
out = out.at[:].set(self._array_backend.fft.rfftn(arr, s=shape))
|
114
|
+
def rfftn(arr, out, shape=fast_shape):
|
115
|
+
return self._array_backend.fft.rfftn(arr, s=shape)
|
123
116
|
|
124
|
-
def irfftn(
|
125
|
-
arr,
|
126
|
-
) -> None:
|
127
|
-
out = out.at[:].set(self._array_backend.fft.irfftn(arr, s=shape))
|
117
|
+
def irfftn(arr, out, shape=fast_shape):
|
118
|
+
return self._array_backend.fft.irfftn(arr, s=shape)
|
128
119
|
|
129
120
|
return rfftn, irfftn
|
130
121
|
|
131
|
-
def
|
132
|
-
|
122
|
+
def compute_convolution_shapes(
|
123
|
+
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
124
|
+
) -> Tuple[List[int], List[int], List[int]]:
|
125
|
+
conv_shape, fast_shape, fast_ft_shape = super().compute_convolution_shapes(
|
126
|
+
arr1_shape, arr2_shape
|
127
|
+
)
|
133
128
|
|
134
|
-
|
135
|
-
|
136
|
-
|
129
|
+
is_odd = fast_shape[-1] % 2
|
130
|
+
fast_shape[-1] += is_odd
|
131
|
+
fast_ft_shape[-1] += is_odd
|
137
132
|
|
138
|
-
|
133
|
+
return conv_shape, fast_shape, fast_ft_shape
|
134
|
+
|
135
|
+
def rigid_transform(
|
139
136
|
self,
|
140
|
-
arr,
|
141
|
-
rotation_matrix,
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
) ->
|
137
|
+
arr: BackendArray,
|
138
|
+
rotation_matrix: BackendArray,
|
139
|
+
out: BackendArray = None,
|
140
|
+
out_mask: BackendArray = None,
|
141
|
+
translation: BackendArray = None,
|
142
|
+
arr_mask: BackendArray = None,
|
143
|
+
order: int = 1,
|
144
|
+
**kwargs,
|
145
|
+
) -> Tuple[BackendArray, BackendArray]:
|
149
146
|
rotate_mask = arr_mask is not None
|
150
|
-
|
151
|
-
translation = self.zeros(arr.ndim) if translation is None else translation
|
152
|
-
|
153
|
-
indices = self._array_backend.indices(arr.shape).reshape(
|
154
|
-
(len(arr.shape), -1)
|
155
|
-
).astype(self._float_dtype)
|
147
|
+
center = self.divide(self.to_backend_array(arr.shape), 2)[:, None]
|
156
148
|
|
157
|
-
|
158
|
-
|
159
|
-
center = self.center_of_mass(arr, cutoff=0)
|
160
|
-
center = center[:, None]
|
149
|
+
indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
|
150
|
+
indices = indices.reshape((arr.ndim, -1))
|
161
151
|
indices = indices.at[:].add(-center)
|
162
|
-
|
163
|
-
indices = self._array_backend.matmul(rotation_matrix, indices)
|
152
|
+
indices = self._array_backend.matmul(rotation_matrix.T, indices)
|
164
153
|
indices = indices.at[:].add(center)
|
154
|
+
if translation is not None:
|
155
|
+
indices = indices.at[:].add(translation)
|
165
156
|
|
166
|
-
out = self.
|
167
|
-
|
168
|
-
|
169
|
-
out = out.at[out_slice].set(
|
170
|
-
self.scipy.ndimage.map_coordinates(
|
171
|
-
arr, indices, order=order
|
172
|
-
).reshape(arr.shape)
|
157
|
+
out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
|
158
|
+
arr.shape
|
173
159
|
)
|
174
160
|
|
161
|
+
out_mask = arr_mask
|
175
162
|
if rotate_mask:
|
176
|
-
out_mask = self.
|
177
|
-
|
178
|
-
|
179
|
-
self.scipy.ndimage.map_coordinates(
|
180
|
-
arr_mask, indices, order=order
|
181
|
-
).reshape(arr.shape)
|
182
|
-
)
|
163
|
+
out_mask = self.scipy.ndimage.map_coordinates(
|
164
|
+
arr_mask, indices, order=order
|
165
|
+
).reshape(arr_mask.shape)
|
183
166
|
|
184
|
-
|
185
|
-
case 0:
|
186
|
-
return None
|
187
|
-
case 1:
|
188
|
-
return out
|
189
|
-
case 2:
|
190
|
-
return out_mask
|
191
|
-
case 3:
|
192
|
-
return out, out_mask
|
167
|
+
return out, out_mask
|
193
168
|
|
194
169
|
def max_score_over_rotations(
|
195
170
|
self,
|
196
|
-
|
197
|
-
|
198
|
-
|
171
|
+
scores: BackendArray,
|
172
|
+
max_scores: BackendArray,
|
173
|
+
rotations: BackendArray,
|
199
174
|
rotation_index: int,
|
200
|
-
):
|
175
|
+
) -> Tuple[BackendArray, BackendArray]:
|
176
|
+
update = self.greater(max_scores, scores)
|
177
|
+
max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
|
178
|
+
rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
|
179
|
+
return max_scores, rotations
|
180
|
+
|
181
|
+
def _scan(
|
182
|
+
self,
|
183
|
+
matching_data: type,
|
184
|
+
splits: Tuple[Tuple[slice, slice]],
|
185
|
+
n_jobs: int,
|
186
|
+
callback_class,
|
187
|
+
rotate_mask: bool = False,
|
188
|
+
**kwargs,
|
189
|
+
) -> List:
|
201
190
|
"""
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
Parameters
|
206
|
-
----------
|
207
|
-
score_space : CupyArray
|
208
|
-
The score space to compare against internal_scores.
|
209
|
-
internal_scores : CupyArray
|
210
|
-
The internal scores to update with maximum scores.
|
211
|
-
internal_rotations : CupyArray
|
212
|
-
The internal rotations corresponding to the maximum scores.
|
213
|
-
rotation_index : int
|
214
|
-
The index representing the current rotation.
|
191
|
+
Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
|
192
|
+
:py:class:`tme.analyzer.MaxScoreOverRotations`.
|
215
193
|
"""
|
216
|
-
|
217
|
-
|
218
|
-
|
194
|
+
from ._jax_utils import scan as scan_inner
|
195
|
+
|
196
|
+
pad_target = True if len(splits) > 1 else False
|
197
|
+
convolution_mode = "valid" if pad_target else "same"
|
198
|
+
target_pad = matching_data.target_padding(pad_target=pad_target)
|
199
|
+
|
200
|
+
target_shape = tuple(
|
201
|
+
(x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
|
202
|
+
)
|
203
|
+
fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
|
204
|
+
target_shape=self.to_numpy_array(target_shape),
|
205
|
+
template_shape=self.to_numpy_array(matching_data._template.shape),
|
206
|
+
pad_fourier=False,
|
207
|
+
)
|
208
|
+
|
209
|
+
analyzer_args = {
|
210
|
+
"convolution_mode": convolution_mode,
|
211
|
+
"fourier_shift": shift,
|
212
|
+
"targetshape": target_shape,
|
213
|
+
"templateshape": matching_data._template.shape,
|
214
|
+
}
|
215
|
+
|
216
|
+
create_target_filter = matching_data.target_filter is not None
|
217
|
+
create_template_filter = matching_data.template_filter is not None
|
218
|
+
create_filter = create_target_filter or create_template_filter
|
219
|
+
|
220
|
+
ret, template_filter, target_filter = [], 1, 1
|
221
|
+
rotation_mapping = {
|
222
|
+
self.tobytes(matching_data.rotations[i]): i
|
223
|
+
for i in range(matching_data.rotations.shape[0])
|
224
|
+
}
|
225
|
+
for split_start in range(0, len(splits), n_jobs):
|
226
|
+
split_subset = splits[split_start : (split_start + n_jobs)]
|
227
|
+
if not len(split_subset):
|
228
|
+
continue
|
229
|
+
|
230
|
+
targets, translation_offsets = [], []
|
231
|
+
for target_split, template_split in split_subset:
|
232
|
+
base = matching_data.subset_by_slice(
|
233
|
+
target_slice=target_split,
|
234
|
+
target_pad=target_pad,
|
235
|
+
template_slice=template_split,
|
236
|
+
)
|
237
|
+
translation_offsets.append(base._translation_offset)
|
238
|
+
targets.append(self.topleft_pad(base._target, fast_shape))
|
239
|
+
|
240
|
+
if create_filter:
|
241
|
+
filter_args = {
|
242
|
+
"data_rfft": self.fft.rfftn(targets[0]),
|
243
|
+
"return_real_fourier": True,
|
244
|
+
"shape_is_real_fourier": False,
|
245
|
+
}
|
246
|
+
|
247
|
+
if create_template_filter:
|
248
|
+
template_filter = matching_data.template_filter(
|
249
|
+
shape=matching_data._template.shape, **filter_args
|
250
|
+
)["data"]
|
251
|
+
template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
|
252
|
+
|
253
|
+
if create_target_filter:
|
254
|
+
target_filter = matching_data.template_filter(
|
255
|
+
shape=fast_shape, **filter_args
|
256
|
+
)["data"]
|
257
|
+
target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
|
258
|
+
|
259
|
+
create_filter, create_template_filter, create_target_filter = (False,) * 3
|
260
|
+
base, targets = None, self._array_backend.stack(targets)
|
261
|
+
scores, rotations = scan_inner(
|
262
|
+
targets,
|
263
|
+
matching_data.template,
|
264
|
+
matching_data.template_mask,
|
265
|
+
matching_data.rotations,
|
266
|
+
template_filter,
|
267
|
+
target_filter,
|
268
|
+
fast_shape,
|
269
|
+
rotate_mask,
|
270
|
+
)
|
271
|
+
|
272
|
+
for index in range(scores.shape[0]):
|
273
|
+
temp = callback_class(
|
274
|
+
scores=scores[index],
|
275
|
+
rotations=rotations[index],
|
276
|
+
thread_safe=False,
|
277
|
+
offset=translation_offsets[index],
|
278
|
+
)
|
279
|
+
temp.rotation_mapping = rotation_mapping
|
280
|
+
ret.append(tuple(temp._postprocess(**analyzer_args)))
|
281
|
+
|
282
|
+
return ret
|