pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,241 @@
|
|
1
|
+
""" Backend using Apple's MLX library for template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple, List, Callable
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from .npfftw_backend import NumpyFFTWBackend
|
13
|
+
from ..types import NDArray, MlxArray, Scalar, shm_type
|
14
|
+
|
15
|
+
|
16
|
+
class MLXBackend(NumpyFFTWBackend):
    """
    A mlx-based matching backend.

    Array operations are dispatched to ``mlx.core`` via the
    :py:class:`NumpyFFTWBackend` machinery (``array_backend=mx``). A few
    operations (``roll``, ``topk_indices``, ``rigid_transform``, ``indices``,
    ``unique``) are computed with NumPy and converted back to MLX arrays.
    """

    def __init__(
        self,
        device="cpu",
        float_dtype=None,
        complex_dtype=None,
        int_dtype=None,
        overflow_safe_dtype=None,
        **kwargs,
    ):
        """
        Initialize the MLX backend.

        Parameters
        ----------
        device : str, optional
            ``"cpu"`` selects ``mlx.core.cpu``; any other value selects
            ``mlx.core.gpu``. Defaults to ``"cpu"``.
        float_dtype : optional
            Floating point dtype, defaults to ``mlx.core.float32``.
        complex_dtype : optional
            Complex dtype, defaults to ``mlx.core.complex64``.
        int_dtype : optional
            Integer dtype, defaults to ``mlx.core.int32``.
        overflow_safe_dtype : optional
            Dtype used for overflow-sensitive computations, defaults to
            ``mlx.core.float32``.
        **kwargs
            Accepted for interface compatibility; not used.
        """
        # Imported lazily so the module can be imported without mlx installed.
        import mlx.core as mx

        device = mx.cpu if device == "cpu" else mx.gpu
        float_dtype = mx.float32 if float_dtype is None else float_dtype
        complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
        int_dtype = mx.int32 if int_dtype is None else int_dtype
        if overflow_safe_dtype is None:
            overflow_safe_dtype = mx.float32

        super().__init__(
            array_backend=mx,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )

        self.device = device

    def to_backend_array(self, arr: NDArray) -> MlxArray:
        """Convert ``arr`` to an MLX array via ``mlx.core.array``."""
        return self._array_backend.array(arr)

    def to_numpy_array(self, arr: MlxArray) -> NDArray:
        """Convert ``arr`` to a NumPy array via ``numpy.array``."""
        return np.array(arr)

    def to_cpu_array(self, arr: MlxArray) -> NDArray:
        """Return ``arr`` unchanged; no device transfer is performed here."""
        return arr

    def free_cache(self):
        """No-op; this backend does not release any cached memory."""
        pass

    def mod(self, arr1: MlxArray, arr2: MlxArray, out: MlxArray = None) -> MlxArray:
        """
        Element-wise ``arr1 % arr2``.

        If ``out`` is given, the result is written into it and ``None`` is
        returned; otherwise the result is returned.
        """
        if out is not None:
            out[:] = arr1 % arr2
            return None
        return arr1 % arr2

    def add(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
        """
        Element-wise sum of ``x1`` and ``x2``.

        Both operands are converted to MLX arrays first. If ``out`` is given,
        the result is written into it and ``None`` is returned.
        """
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)

        if out is not None:
            out[:] = self._array_backend.add(x1, x2, **kwargs)
            return None
        return self._array_backend.add(x1, x2, **kwargs)

    def multiply(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
        """
        Element-wise product of ``x1`` and ``x2``.

        Both operands are converted to MLX arrays first. If ``out`` is given,
        the result is written into it and ``None`` is returned.
        """
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)

        if out is not None:
            out[:] = self._array_backend.multiply(x1, x2, **kwargs)
            return None
        return self._array_backend.multiply(x1, x2, **kwargs)

    def std(self, arr: MlxArray, axis) -> Scalar:
        """Standard deviation of ``arr`` along ``axis``, as sqrt of variance."""
        return self._array_backend.sqrt(arr.var(axis=axis))

    def unique(self, *args, **kwargs):
        """
        Compute unique elements using :py:func:`numpy.unique`.

        Returns
        -------
        MlxArray or list of MlxArray
            The unique values as an MLX array. If ``numpy.unique`` returns
            multiple arrays (e.g. ``return_counts=True``), a list of MLX
            arrays is returned.
        """
        ret = np.unique(*args, **kwargs)
        if isinstance(ret, tuple):
            return [self.to_backend_array(x) for x in ret]
        # Convert the single-array result too, so callers always get MLX arrays.
        return self.to_backend_array(ret)

    def tobytes(self, arr):
        """Serialize ``arr`` to bytes via its NumPy representation."""
        return self.to_numpy_array(arr).tobytes()

    def full(self, shape, fill_value, dtype=None):
        """Create an array of ``shape`` filled with ``fill_value``."""
        return self._array_backend.full(shape=shape, dtype=dtype, vals=fill_value)

    def fill(self, arr: MlxArray, value: Scalar) -> MlxArray:
        """Set every element of ``arr`` to ``value`` and return ``arr``."""
        arr[:] = value
        return arr

    def zeros(self, shape: Tuple[int], dtype: type = None) -> MlxArray:
        """Create a zero-filled array of ``shape`` and ``dtype``."""
        return self._array_backend.zeros(shape=shape, dtype=dtype)

    def roll(self, a: MlxArray, shift, axis, **kwargs):
        """Roll ``a`` along ``axis`` by ``shift`` using the NumPy backend."""
        a = self.to_numpy_array(a)
        ret = NumpyFFTWBackend().roll(
            a,
            shift=shift,
            axis=axis,
            **kwargs,
        )
        return self.to_backend_array(ret)

    def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
        """
        Extract the centered portion of an array based on a new shape.

        Parameters
        ----------
        arr : NDArray
            Input array.
        newshape : tuple
            Desired shape for the central portion.

        Returns
        -------
        NDArray
            Central portion of the array with shape `newshape`.

        References
        ----------
        .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py
        """
        new_shape = self.to_backend_array(newshape)
        current_shape = self.to_backend_array(arr.shape)
        starts = self.subtract(current_shape, new_shape)
        # Casting the halved offset to int truncates, matching floor for
        # non-negative offsets.
        starts = self.astype(self.divide(starts, 2), self._int_dtype)
        stops = self.astype(self.add(starts, newshape), self._int_dtype)
        starts, stops = starts.tolist(), stops.tolist()
        box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
        return arr[box]

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int] = None,
        inv_output_shape: Tuple[int] = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Build forward and inverse real-valued FFT callables.

        Parameters
        ----------
        fwd_shape : tuple of int
            Real-space shape of the forward transform input.
        inv_shape : tuple of int, optional
            Accepted for interface compatibility; not used.
        inv_output_shape : tuple of int, optional
            Real-space output shape of the inverse transform; defaults to
            ``fwd_shape``.
        fwd_axes, inv_axes : tuple of int, optional
            Axes of the forward and inverse transform.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        tuple of callable
            ``(rfftn, irfftn)``. Each takes ``(arr, out=None)``. When ``out``
            is given, the result is written into it and ``out`` is returned;
            otherwise the freshly computed result is returned.
        """
        # Runs on mlx.core.cpu until Metal support is available
        rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)

        def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
            ret = self._array_backend.fft.rfftn(
                arr, s=s, axes=axes, stream=self._array_backend.cpu
            )
            if out is None:
                return ret
            out[:] = ret
            return out

        def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
            ret = self._array_backend.fft.irfftn(
                arr, s=s, axes=axes, stream=self._array_backend.cpu
            )
            if out is None:
                return ret
            out[:] = ret
            return out

        return rfftn, irfftn

    def rfftn(self, arr, *args, **kwargs):
        """Real n-dimensional FFT on the MLX CPU stream; ``*args`` are ignored."""
        return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)

    def irfftn(self, arr, *args, **kwargs):
        """Inverse real n-dimensional FFT on the MLX CPU stream; ``*args`` are ignored."""
        return self.fft.irfftn(arr, stream=self._array_backend.cpu, **kwargs)

    def from_sharedarr(self, arr: MlxArray) -> MlxArray:
        """Return ``arr`` unchanged; no shared-memory unwrapping is done."""
        return arr

    @staticmethod
    def to_sharedarr(arr: MlxArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged; ``shared_memory_handler`` is not used."""
        return arr

    def topk_indices(self, arr: NDArray, k: int):
        """
        Indices of the ``k`` largest entries, computed by the NumPy backend.

        Returns a list of MLX arrays, one per item returned by
        ``NumpyFFTWBackend.topk_indices``.
        """
        arr = self.to_numpy_array(arr)
        ret = NumpyFFTWBackend().topk_indices(arr=arr, k=k)
        ret = [self.to_backend_array(x) for x in ret]
        return ret

    def rigid_transform(
        self,
        arr: NDArray,
        rotation_matrix: NDArray,
        arr_mask: NDArray = None,
        translation: NDArray = None,
        use_geometric_center: bool = False,
        out: NDArray = None,
        out_mask: NDArray = None,
        order: int = 3,
        **kwargs,
    ) -> Tuple[MlxArray, MlxArray]:
        """
        Rotate and translate ``arr`` (and optionally ``arr_mask``).

        The transform is computed by ``NumpyFFTWBackend.rigid_transform`` on
        NumPy copies of the inputs and copied back into MLX arrays.

        Parameters
        ----------
        arr : NDArray
            Array to transform.
        rotation_matrix : NDArray
            Rotation matrix.
        arr_mask : NDArray, optional
            Mask to transform alongside ``arr``.
        translation : NDArray, optional
            Translation to apply.
        use_geometric_center : bool, optional
            Forwarded to the NumPy backend.
        out : NDArray, optional
            Output array for the transformed ``arr``. A zero array of
            ``arr.shape`` is allocated when omitted.
        out_mask : NDArray, optional
            Output array for the transformed mask. Written into only when a
            transformed mask is produced.
        order : int, optional
            Interpolation order, forwarded to the NumPy backend.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        tuple
            ``(out, out_mask)``. ``out_mask`` is ``None`` when no mask output
            was provided and none was produced.
        """
        arr = self.to_numpy_array(arr)
        rotation_matrix = self.to_numpy_array(rotation_matrix)

        if arr_mask is not None:
            arr_mask = self.to_numpy_array(arr_mask)

        if translation is not None:
            translation = self.to_numpy_array(translation)

        if out is None:
            out = self.zeros(arr.shape)

        out_pass, out_mask_pass = NumpyFFTWBackend().rigid_transform(
            arr=arr,
            rotation_matrix=rotation_matrix,
            arr_mask=arr_mask,
            translation=translation,
            use_geometric_center=use_geometric_center,
            order=order,
        )
        out[:] = self.to_backend_array(out_pass)

        # Only touch out_mask when the NumPy backend actually produced a mask;
        # otherwise a caller-provided out_mask is left as-is.
        if out_mask_pass is not None:
            out_mask_pass = self.to_backend_array(out_mask_pass)
            if out_mask is not None:
                out_mask[:] = out_mask_pass
            else:
                out_mask = out_mask_pass

        return out, out_mask

    def indices(self, arr: List) -> MlxArray:
        """Grid of indices for shape ``arr``, computed by the NumPy backend."""
        ret = NumpyFFTWBackend().indices(arr)
        return self.to_backend_array(ret)
|