pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- pytme-0.2.2.data/scripts/match_template.py +1187 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +126 -87
- scripts/match_template.py +596 -209
- scripts/match_template_filters.py +571 -223
- scripts/postprocess.py +170 -71
- scripts/preprocessor_gui.py +179 -86
- scripts/refine_matches.py +567 -159
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +627 -855
- tme/backends/__init__.py +41 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +120 -225
- tme/backends/jax_backend.py +282 -0
- tme/backends/matching_backend.py +464 -388
- tme/backends/mlx_backend.py +45 -68
- tme/backends/npfftw_backend.py +256 -514
- tme/backends/pytorch_backend.py +41 -154
- tme/density.py +312 -421
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +366 -303
- tme/matching_exhaustive.py +279 -1521
- tme/matching_optimization.py +234 -129
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +281 -387
- tme/memory.py +377 -0
- tme/orientations.py +226 -66
- tme/parser.py +3 -4
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +217 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +55 -0
- tme/preprocessing/frequency_filters.py +388 -0
- tme/preprocessing/tilt_series.py +1011 -0
- tme/preprocessor.py +574 -530
- tme/structure.py +495 -189
- tme/types.py +5 -3
- pytme-0.2.0b0.data/scripts/match_template.py +0 -800
- pytme-0.2.0b0.dist-info/METADATA +0 -73
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/backends/mlx_backend.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
""" Backend Apple's MLX library for template matching.
|
1
|
+
""" Backend using Apple's MLX library for template matching.
|
2
2
|
|
3
3
|
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
4
|
|
@@ -9,34 +9,38 @@ from typing import Tuple, List, Callable
|
|
9
9
|
import numpy as np
|
10
10
|
|
11
11
|
from .npfftw_backend import NumpyFFTWBackend
|
12
|
-
from ..types import NDArray, MlxArray, Scalar
|
12
|
+
from ..types import NDArray, MlxArray, Scalar, shm_type
|
13
13
|
|
14
14
|
|
15
15
|
class MLXBackend(NumpyFFTWBackend):
|
16
16
|
"""
|
17
|
-
A
|
17
|
+
A mlx-based matching backend.
|
18
18
|
"""
|
19
19
|
|
20
20
|
def __init__(
|
21
21
|
self,
|
22
22
|
device="cpu",
|
23
|
-
|
23
|
+
float_dtype=None,
|
24
24
|
complex_dtype=None,
|
25
|
-
|
25
|
+
int_dtype=None,
|
26
|
+
overflow_safe_dtype=None,
|
26
27
|
**kwargs,
|
27
28
|
):
|
28
29
|
import mlx.core as mx
|
29
30
|
|
30
31
|
device = mx.cpu if device == "cpu" else mx.gpu
|
31
|
-
|
32
|
+
float_dtype = mx.float32 if float_dtype is None else float_dtype
|
32
33
|
complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
|
33
|
-
|
34
|
+
int_dtype = mx.int32 if int_dtype is None else int_dtype
|
35
|
+
if overflow_safe_dtype is None:
|
36
|
+
overflow_safe_dtype = mx.float32
|
34
37
|
|
35
38
|
super().__init__(
|
36
39
|
array_backend=mx,
|
37
|
-
|
40
|
+
float_dtype=float_dtype,
|
38
41
|
complex_dtype=complex_dtype,
|
39
|
-
|
42
|
+
int_dtype=int_dtype,
|
43
|
+
overflow_safe_dtype=overflow_safe_dtype,
|
40
44
|
)
|
41
45
|
|
42
46
|
self.device = device
|
@@ -68,6 +72,15 @@ class MLXBackend(NumpyFFTWBackend):
|
|
68
72
|
return None
|
69
73
|
return self._array_backend.add(x1, x2, **kwargs)
|
70
74
|
|
75
|
+
def multiply(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
|
76
|
+
x1 = self.to_backend_array(x1)
|
77
|
+
x2 = self.to_backend_array(x2)
|
78
|
+
|
79
|
+
if out is not None:
|
80
|
+
out[:] = self._array_backend.multiply(x1, x2, **kwargs)
|
81
|
+
return None
|
82
|
+
return self._array_backend.multiply(x1, x2, **kwargs)
|
83
|
+
|
71
84
|
def std(self, arr: MlxArray, axis) -> Scalar:
|
72
85
|
return self._array_backend.sqrt(arr.var(axis=axis))
|
73
86
|
|
@@ -80,30 +93,12 @@ class MLXBackend(NumpyFFTWBackend):
|
|
80
93
|
def tobytes(self, arr):
|
81
94
|
return self.to_numpy_array(arr).tobytes()
|
82
95
|
|
83
|
-
def preallocate_array(self, shape: Tuple[int], dtype: type = None) -> NDArray:
|
84
|
-
"""
|
85
|
-
Returns a byte-aligned array of zeros with specified shape and dtype.
|
86
|
-
|
87
|
-
Parameters
|
88
|
-
----------
|
89
|
-
shape : Tuple[int]
|
90
|
-
Desired shape for the array.
|
91
|
-
dtype : type, optional
|
92
|
-
Desired data type for the array.
|
93
|
-
|
94
|
-
Returns
|
95
|
-
-------
|
96
|
-
NDArray
|
97
|
-
Byte-aligned array of zeros with specified shape and dtype.
|
98
|
-
"""
|
99
|
-
arr = self._array_backend.zeros(shape, dtype=dtype)
|
100
|
-
return arr
|
101
|
-
|
102
96
|
def full(self, shape, fill_value, dtype=None):
|
103
97
|
return self._array_backend.full(shape=shape, dtype=dtype, vals=fill_value)
|
104
98
|
|
105
|
-
def fill(self, arr: MlxArray, value: Scalar) ->
|
99
|
+
def fill(self, arr: MlxArray, value: Scalar) -> MlxArray:
|
106
100
|
arr[:] = value
|
101
|
+
return arr
|
107
102
|
|
108
103
|
def zeros(self, shape: Tuple[int], dtype: type = None) -> MlxArray:
|
109
104
|
return self._array_backend.zeros(shape=shape, dtype=dtype)
|
@@ -141,8 +136,8 @@ class MLXBackend(NumpyFFTWBackend):
|
|
141
136
|
new_shape = self.to_backend_array(newshape)
|
142
137
|
current_shape = self.to_backend_array(arr.shape)
|
143
138
|
starts = self.subtract(current_shape, new_shape)
|
144
|
-
starts = self.astype(self.divide(starts, 2), self.
|
145
|
-
stops = self.astype(self.add(starts, newshape), self.
|
139
|
+
starts = self.astype(self.divide(starts, 2), self._int_dtype)
|
140
|
+
stops = self.astype(self.add(starts, newshape), self._int_dtype)
|
146
141
|
starts, stops = starts.tolist(), stops.tolist()
|
147
142
|
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
148
143
|
return arr[box]
|
@@ -185,13 +180,11 @@ class MLXBackend(NumpyFFTWBackend):
|
|
185
180
|
|
186
181
|
return rfftn, irfftn
|
187
182
|
|
188
|
-
def
|
189
|
-
|
190
|
-
) -> MlxArray:
|
191
|
-
return shm
|
183
|
+
def from_sharedarr(self, arr: MlxArray) -> MlxArray:
|
184
|
+
return arr
|
192
185
|
|
193
186
|
@staticmethod
|
194
|
-
def
|
187
|
+
def to_sharedarr(arr: MlxArray, shared_memory_handler: type = None) -> shm_type:
|
195
188
|
return arr
|
196
189
|
|
197
190
|
def topk_indices(self, arr: NDArray, k: int):
|
@@ -200,7 +193,7 @@ class MLXBackend(NumpyFFTWBackend):
|
|
200
193
|
ret = [self.to_backend_array(x) for x in ret]
|
201
194
|
return ret
|
202
195
|
|
203
|
-
def
|
196
|
+
def rigid_transform(
|
204
197
|
self,
|
205
198
|
arr: NDArray,
|
206
199
|
rotation_matrix: NDArray,
|
@@ -210,10 +203,8 @@ class MLXBackend(NumpyFFTWBackend):
|
|
210
203
|
out: NDArray = None,
|
211
204
|
out_mask: NDArray = None,
|
212
205
|
order: int = 3,
|
206
|
+
**kwargs,
|
213
207
|
) -> None:
|
214
|
-
rotate_mask = arr_mask is not None
|
215
|
-
return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
|
216
|
-
|
217
208
|
arr = self.to_numpy_array(arr)
|
218
209
|
rotation_matrix = self.to_numpy_array(rotation_matrix)
|
219
210
|
|
@@ -223,46 +214,32 @@ class MLXBackend(NumpyFFTWBackend):
|
|
223
214
|
if translation is not None:
|
224
215
|
translation = self.to_numpy_array(translation)
|
225
216
|
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
out_mask_pass = self.to_numpy_array(out_mask)
|
217
|
+
if out is None:
|
218
|
+
out = self.zeros(arr.shape)
|
219
|
+
if out_mask is None and arr_mask is not None:
|
220
|
+
out_mask_pass = self.zeros(arr_mask.shape)
|
231
221
|
|
232
|
-
ret = NumpyFFTWBackend().
|
222
|
+
ret = NumpyFFTWBackend().rigid_transform(
|
233
223
|
arr=arr,
|
234
224
|
rotation_matrix=rotation_matrix,
|
235
225
|
arr_mask=arr_mask,
|
236
226
|
translation=translation,
|
237
227
|
use_geometric_center=use_geometric_center,
|
238
|
-
out=out_pass,
|
239
|
-
out_mask=out_mask_pass,
|
240
228
|
order=order,
|
241
229
|
)
|
242
230
|
|
243
|
-
|
244
|
-
|
245
|
-
out_pass = ret
|
246
|
-
elif len(ret) == 1 and out_mask is None:
|
247
|
-
out_mask_pass = ret
|
248
|
-
else:
|
249
|
-
out_pass, out_mask_pass = ret
|
231
|
+
out_pass, out_mask_pass = ret
|
232
|
+
out[:] = self.to_backend_array(out_pass)
|
250
233
|
|
251
|
-
if
|
252
|
-
|
234
|
+
if out_mask_pass is not None:
|
235
|
+
out_mask_pass = self.to_backend_array(out_mask_pass)
|
253
236
|
|
254
237
|
if out_mask is not None:
|
255
|
-
out_mask[:] =
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
case 1:
|
261
|
-
return out
|
262
|
-
case 2:
|
263
|
-
return out_mask
|
264
|
-
case 3:
|
265
|
-
return out, out_mask
|
238
|
+
out_mask[:] = out_mask_pass
|
239
|
+
else:
|
240
|
+
out_mask = out_mask_pass
|
241
|
+
|
242
|
+
return out, out_mask
|
266
243
|
|
267
244
|
def indices(self, arr: List) -> MlxArray:
|
268
245
|
ret = NumpyFFTWBackend().indices(arr)
|