pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
- pytme-0.2.3.dist-info/METADATA +92 -0
- pytme-0.2.3.dist-info/RECORD +75 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
- pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +219 -216
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +86 -54
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +181 -94
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +458 -813
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +187 -0
- tme/backends/cupy_backend.py +109 -226
- tme/backends/jax_backend.py +230 -152
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +240 -507
- tme/backends/pytorch_backend.py +30 -151
- tme/density.py +248 -371
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +328 -284
- tme/matching_exhaustive.py +195 -1499
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +887 -0
- tme/matching_utils.py +287 -388
- tme/memory.py +377 -0
- tme/orientations.py +78 -21
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +44 -72
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/backends/jax_backend.py
CHANGED
@@ -1,17 +1,38 @@
|
|
1
1
|
""" Backend using jax for template matching.
|
2
2
|
|
3
|
-
Copyright (c) 2023 European Molecular Biology Laboratory
|
3
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
|
-
from
|
7
|
+
from functools import wraps
|
8
|
+
from typing import Tuple, List, Callable
|
9
|
+
|
10
|
+
from ..types import BackendArray
|
11
|
+
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
12
|
+
|
13
|
+
|
14
|
+
def emulate_out(func):
    """
    Wrap ``func`` so that it additionally accepts an ``out`` keyword argument.

    JAX arrays cannot be written in place. When ``out`` is supplied, the
    result of ``func`` is placed into it with the functional update
    ``out.at[:].set(result)`` and that updated array is returned; the object
    passed as ``out`` is not modified. When ``out`` is omitted, the result of
    ``func`` is returned as-is.
    """

    @wraps(func)
    def inner(*args, out=None, **kwargs):
        result = func(*args, **kwargs)
        if out is None:
            return result
        # Functional update: yields a new array, callers must use the return value.
        return out.at[:].set(result)

    return inner
|
8
28
|
|
9
|
-
from .npfftw_backend import NumpyFFTWBackend
|
10
29
|
|
11
30
|
class JaxBackend(NumpyFFTWBackend):
|
12
|
-
|
13
|
-
|
14
|
-
|
31
|
+
"""
|
32
|
+
A jax-based matching backend.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
|
15
36
|
import jax.scipy as jsp
|
16
37
|
import jax.numpy as jnp
|
17
38
|
|
@@ -19,24 +40,33 @@ class JaxBackend(NumpyFFTWBackend):
|
|
19
40
|
complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
|
20
41
|
int_dtype = jnp.int32 if int_dtype is None else int_dtype
|
21
42
|
|
22
|
-
self.scipy = jsp
|
23
43
|
super().__init__(
|
24
44
|
array_backend=jnp,
|
25
45
|
float_dtype=float_dtype,
|
26
46
|
complex_dtype=complex_dtype,
|
27
47
|
int_dtype=int_dtype,
|
48
|
+
overflow_safe_dtype=float_dtype,
|
28
49
|
)
|
50
|
+
self.scipy = jsp
|
51
|
+
self._create_ufuncs()
|
52
|
+
try:
|
53
|
+
from ._jax_utils import scan as _
|
54
|
+
|
55
|
+
self.scan = self._scan
|
56
|
+
except Exception:
|
57
|
+
pass
|
29
58
|
|
30
|
-
def
|
31
|
-
return
|
59
|
+
def from_sharedarr(self, arr: BackendArray) -> BackendArray:
    """
    Return ``arr`` unchanged.

    This backend performs no shared-memory conversion; the very object that
    was passed in is handed back.
    """
    passthrough = arr
    return passthrough
|
32
61
|
|
33
|
-
|
34
|
-
|
62
|
+
@staticmethod
def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
    """
    Return ``arr`` unchanged.

    ``shared_memory_handler`` is accepted but not used; no shared-memory
    object is created by this backend.
    """
    passthrough = arr
    return passthrough
|
36
65
|
|
37
|
-
def topleft_pad(
|
38
|
-
|
39
|
-
|
66
|
+
def topleft_pad(
|
67
|
+
self, arr: BackendArray, shape: Tuple[int], padval: int = 0
|
68
|
+
) -> BackendArray:
|
69
|
+
b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
|
40
70
|
aind = [slice(None, None)] * arr.ndim
|
41
71
|
bind = [slice(None, None)] * arr.ndim
|
42
72
|
for i in range(arr.ndim):
|
@@ -47,43 +77,29 @@ class JaxBackend(NumpyFFTWBackend):
|
|
47
77
|
b = b.at[tuple(bind)].set(arr[tuple(aind)])
|
48
78
|
return b
|
49
79
|
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
out = out.at[:].set(ret)
|
74
|
-
return ret
|
75
|
-
|
76
|
-
def divide(self, x1, x2, out = None, *args, **kwargs):
|
77
|
-
x1 = self.to_backend_array(x1)
|
78
|
-
x2 = self.to_backend_array(x2)
|
79
|
-
ret = self._array_backend.divide(x1, x2, *args, **kwargs)
|
80
|
-
if out is not None:
|
81
|
-
out = out.at[:].set(ret)
|
82
|
-
return ret
|
83
|
-
|
84
|
-
def fill(self, arr, value: float) -> None:
|
85
|
-
arr.at[:].set(value)
|
86
|
-
|
80
|
+
def _create_ufuncs(self):
    """
    Bind array-backend functions as instance attributes.

    The element-wise functions listed first are wrapped with ``emulate_out``
    so they accept an ``out=`` keyword; ``zeros`` and ``full`` are bound from
    the array backend without wrapping. Because they are set on the
    instance, these attributes take precedence over same-named methods of
    the class hierarchy.
    """
    backend = self._array_backend
    out_emulated = (
        "add",
        "subtract",
        "multiply",
        "divide",
        "square",
        "sqrt",
        "maximum",
    )
    for name in out_emulated:
        setattr(self, name, staticmethod(emulate_out(getattr(backend, name))))

    for name in ("zeros", "full"):
        setattr(self, name, staticmethod(getattr(backend, name)))
|
98
|
+
|
99
|
+
def fill(self, arr: BackendArray, value: float) -> BackendArray:
    """
    Return a new array with the shape and dtype of ``arr`` filled with ``value``.

    ``arr`` itself is left untouched; callers must use the returned array.
    """
    xp = self._array_backend
    return xp.full(shape=arr.shape, fill_value=value, dtype=arr.dtype)
|
87
103
|
|
88
104
|
def build_fft(
|
89
105
|
self,
|
@@ -92,127 +108,189 @@ class JaxBackend(NumpyFFTWBackend):
|
|
92
108
|
inverse_fast_shape: Tuple[int] = None,
|
93
109
|
**kwargs,
|
94
110
|
) -> Tuple[Callable, Callable]:
|
95
|
-
"""
|
96
|
-
Build fft builder functions.
|
97
|
-
|
98
|
-
Parameters
|
99
|
-
----------
|
100
|
-
fast_shape : tuple
|
101
|
-
Tuple of integers corresponding to fast convolution shape
|
102
|
-
(see :py:meth:`PytorchBackend.compute_convolution_shapes`).
|
103
|
-
fast_ft_shape : tuple
|
104
|
-
Tuple of integers corresponding to the shape of the Fourier
|
105
|
-
transform array (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
|
106
|
-
inverse_fast_shape : tuple, optional
|
107
|
-
Output shape of the inverse Fourier transform. By default fast_shape.
|
108
|
-
**kwargs : dict, optional
|
109
|
-
Unused keyword arguments.
|
110
|
-
|
111
|
-
Returns
|
112
|
-
-------
|
113
|
-
tuple
|
114
|
-
Tupple containing callable rfft and irfft object.
|
115
|
-
"""
|
116
111
|
if inverse_fast_shape is None:
|
117
112
|
inverse_fast_shape = fast_shape
|
118
113
|
|
119
|
-
def rfftn(
|
120
|
-
arr,
|
121
|
-
) -> None:
|
122
|
-
out = out.at[:].set(self._array_backend.fft.rfftn(arr, s=shape))
|
114
|
+
def rfftn(arr, out, shape=fast_shape):
    """Real N-d FFT of ``arr`` computed with ``s=shape``; ``out`` is accepted but ignored."""
    fft_module = self._array_backend.fft
    return fft_module.rfftn(arr, s=shape)
|
123
116
|
|
124
|
-
def irfftn(
|
125
|
-
arr,
|
126
|
-
) -> None:
|
127
|
-
out = out.at[:].set(self._array_backend.fft.irfftn(arr, s=shape))
|
117
|
+
def irfftn(arr, out, shape=inverse_fast_shape):
    """
    Inverse real N-d FFT of ``arr`` computed with ``s=shape``.

    ``shape`` defaults to the ``inverse_fast_shape`` resolved by the
    enclosing ``build_fft`` (which itself falls back to ``fast_shape``), so a
    caller-supplied inverse output shape is honoured. ``out`` is accepted but
    ignored; the transform is returned.
    """
    return self._array_backend.fft.irfftn(arr, s=shape)
|
128
119
|
|
129
120
|
return rfftn, irfftn
|
130
121
|
|
131
|
-
def
|
132
|
-
return shm
|
133
|
-
|
134
|
-
@staticmethod
|
135
|
-
def arr_to_sharedarr(arr, shared_memory_handler: type = None):
|
136
|
-
return arr
|
137
|
-
|
138
|
-
def rotate_array(
|
122
|
+
def rigid_transform(
    self,
    arr: BackendArray,
    rotation_matrix: BackendArray,
    out: BackendArray = None,
    out_mask: BackendArray = None,
    translation: BackendArray = None,
    arr_mask: BackendArray = None,
    order: int = 1,
    **kwargs,
) -> Tuple[BackendArray, BackendArray]:
    """
    Rotate, and optionally translate, ``arr`` and ``arr_mask``.

    Sampling coordinates are generated on the grid of ``arr``, centred at
    ``(arr.shape - 1) / 2``, multiplied by ``rotation_matrix.T``, shifted back,
    offset by ``translation`` and then interpolated with
    ``self.scipy.ndimage.map_coordinates``.

    Parameters
    ----------
    arr : BackendArray
        Array to transform.
    rotation_matrix : BackendArray
        Square matrix of size ``arr.ndim``; its transpose is applied to the
        centred sampling coordinates.
    out : BackendArray, optional
        Ignored. JAX arrays are immutable, so a new array is always returned.
    out_mask : BackendArray, optional
        Ignored, see ``out``.
    translation : BackendArray, optional
        Per-axis offset added to the sampling coordinates after rotation,
        either of shape ``(arr.ndim,)`` or ``(arr.ndim, 1)``.
    arr_mask : BackendArray, optional
        Mask sampled at the same coordinates as ``arr``; it must have the
        same number of dimensions and elements as ``arr``.
    order : int, optional
        Interpolation order forwarded to ``map_coordinates``.
    **kwargs : dict, optional
        Unused keyword arguments.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        The transformed array and the transformed mask, which is ``None`` when
        ``arr_mask`` is None.
    """
    xp = self._array_backend
    rotate_mask = arr_mask is not None

    center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
    # indices has shape (ndim, n_points): one row per axis.
    indices = xp.indices(arr.shape, dtype=self._float_dtype)
    indices = indices.reshape((arr.ndim, -1))
    indices = indices.at[:].add(-center)
    indices = xp.matmul(rotation_matrix.T, indices)
    indices = indices.at[:].add(center)
    if translation is not None:
        # A 1-D per-axis translation must become a column so it is added along
        # the axis dimension instead of broadcasting against the points axis.
        translation = xp.asarray(translation)
        if translation.ndim == 1:
            translation = translation[:, None]
        indices = indices.at[:].add(translation)

    out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
        arr.shape
    )

    out_mask = arr_mask
    if rotate_mask:
        out_mask = self.scipy.ndimage.map_coordinates(
            arr_mask, indices, order=order
        ).reshape(arr_mask.shape)

    return out, out_mask
|
193
155
|
|
194
156
|
def max_score_over_rotations(
    self,
    scores: BackendArray,
    max_scores: BackendArray,
    rotations: BackendArray,
    rotation_index: int,
) -> Tuple[BackendArray, BackendArray]:
    """
    Merge ``scores`` into the running maximum ``max_scores``.

    Where ``max_scores`` is strictly greater than ``scores``, the previous
    maximum and its rotation are kept. Everywhere else, including ties,
    ``scores`` is taken and the rotation becomes ``rotation_index``.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        New arrays for the maximum scores and rotation indices. They are
        produced with ``.at[:].set`` and therefore keep the shape and dtype of
        ``max_scores`` and ``rotations``. The inputs are not modified.
    """
    keep_previous = self.greater(max_scores, scores)
    best = self.where(keep_previous, max_scores, scores)
    best_rotation = self.where(keep_previous, rotations, rotation_index)
    return (
        max_scores.at[:].set(best),
        rotations.at[:].set(best_rotation),
    )
|
167
|
+
|
168
|
+
def _scan(
    self,
    matching_data: type,
    splits: Tuple[Tuple[slice, slice]],
    n_jobs: int,
    callback_class,
    rotate_mask: bool = False,
    **kwargs,
) -> List:
    """
    Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
    :py:class:`tme.analyzer.MaxScoreOverRotations`.

    Parameters
    ----------
    matching_data : type
        Matching data object. This method reads its template, template_mask,
        rotations, target/template filters, target_padding, _fourier_padding
        and subset_by_slice.
    splits : Tuple[Tuple[slice, slice]]
        Sequence of ``(target_slice, template_slice)`` pairs. Each
        ``target_slice`` is a per-axis sequence of slices.
    n_jobs : int
        Number of splits stacked together and processed per call of the
        jax ``scan`` kernel.
    callback_class
        Analyzer class. It is instantiated per split with ``scores``,
        ``rotations``, ``thread_safe`` and ``offset``, and its
        ``_postprocess`` output is collected.
    rotate_mask : bool, optional
        Forwarded to the jax ``scan`` kernel.
    **kwargs : dict, optional
        Unused keyword arguments.

    Returns
    -------
    List
        One tuple per split, holding ``callback_class._postprocess`` output.
    """
    from ._jax_utils import scan as scan_inner

    # With more than one split the target is padded and "valid" convolution is
    # used; a single split uses "same" without target padding.
    pad_target = True if len(splits) > 1 else False
    convolution_mode = "valid" if pad_target else "same"
    target_pad = matching_data.target_padding(pad_target=pad_target)

    # Shapes are derived from the first split. This assumes every split has
    # the same extent, so one fast_shape serves all of them (TODO confirm).
    target_shape = tuple(
        (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
    )
    conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
        target_shape=self.to_numpy_array(target_shape),
        template_shape=self.to_numpy_array(matching_data._template.shape),
        pad_fourier=False,
    )

    analyzer_args = {
        "convolution_mode": convolution_mode,
        "fourier_shift": shift,
        "targetshape": target_shape,
        "templateshape": matching_data.template.shape,
        "convolution_shape": conv_shape,
    }

    create_target_filter = matching_data.target_filter is not None
    create_template_filter = matching_data.template_filter is not None
    create_filter = create_target_filter or create_template_filter

    # Applying the filter leads to more FFTs
    fastt_shape = matching_data._template.shape
    if create_template_filter:
        # Only the padded template shape is used here; tshift is unused.
        _, fastt_shape, _, tshift = matching_data._fourier_padding(
            target_shape=self.to_numpy_array(matching_data._template.shape),
            template_shape=self.to_numpy_array(
                [1 for _ in matching_data._template.shape]
            ),
            pad_fourier=False,
        )

    # The value 1 is the placeholder passed to scan_inner when no filter is built.
    ret, template_filter, target_filter = [], 1, 1
    # Maps the raw bytes of each rotation matrix to its index in
    # matching_data.rotations.
    rotation_mapping = {
        self.tobytes(matching_data.rotations[i]): i
        for i in range(matching_data.rotations.shape[0])
    }
    for split_start in range(0, len(splits), n_jobs):
        split_subset = splits[split_start : (split_start + n_jobs)]
        if not len(split_subset):
            continue

        targets, translation_offsets = [], []
        for target_split, template_split in split_subset:
            base = matching_data.subset_by_slice(
                target_slice=target_split,
                target_pad=target_pad,
                template_slice=template_split,
            )
            translation_offsets.append(base._translation_offset)
            targets.append(self.topleft_pad(base._target, fast_shape))

        # Filters are built once, from the first target of the first batch,
        # and reused for every later batch.
        if create_filter:
            filter_args = {
                "data_rfft": self.fft.rfftn(targets[0]),
                "return_real_fourier": True,
                "shape_is_real_fourier": False,
            }

            if create_template_filter:
                template_filter = matching_data.template_filter(
                    shape=fastt_shape, **filter_args
                )["data"]
                # Zero the origin element (presumably the zero-frequency term).
                template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)

            if create_target_filter:
                target_filter = matching_data.target_filter(
                    shape=fast_shape, **filter_args
                )["data"]
                target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)

        # Disable filter construction for all subsequent batches.
        create_filter, create_template_filter, create_target_filter = (False,) * 3
        base, targets = None, self._array_backend.stack(targets)
        scores, rotations = scan_inner(
            targets,
            self.topleft_pad(matching_data.template, fastt_shape),
            self.topleft_pad(matching_data.template_mask, fastt_shape),
            matching_data.rotations,
            template_filter,
            target_filter,
            fast_shape,
            rotate_mask,
        )

        # One analyzer per split in the batch; results keep split order.
        for index in range(scores.shape[0]):
            temp = callback_class(
                scores=scores[index],
                rotations=rotations[index],
                thread_safe=False,
                offset=translation_offsets[index],
            )
            temp.rotation_mapping = rotation_mapping
            ret.append(tuple(temp._postprocess(**analyzer_args)))

    return ret
|
282
|
+
|
283
|
+
def get_available_memory(self) -> int:
    """
    Return a memory budget in bytes for the devices visible to JAX.

    For every non-CPU device in ``jax.devices()``, the ``bytes_limit`` entry of
    ``device.memory_stats()`` is summed. If that sum is positive it is
    returned. Otherwise the parent backend's ``get_available_memory`` (host
    memory) is returned. Note that ``bytes_limit`` is the allocator limit, not
    currently free memory.

    Returns
    -------
    int
        Memory budget in bytes.
    """
    import jax

    device_bytes = 0
    for device in jax.devices():
        if device.platform == "cpu":
            continue
        # memory_stats() may return None on backends that do not expose
        # allocator statistics.
        mem_stats = device.memory_stats() or {}
        device_bytes += mem_stats.get("bytes_limit", 0)

    if device_bytes > 0:
        return device_bytes

    # CPU-only, or accelerators without a reported limit: fall back to the
    # host figure, queried once regardless of how many CPU devices exist.
    return super().get_available_memory()
|