pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,314 @@
|
|
1
|
+
""" Backend using jax for template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from functools import wraps
|
9
|
+
from typing import Tuple, List, Callable
|
10
|
+
|
11
|
+
from ..types import BackendArray
|
12
|
+
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
13
|
+
|
14
|
+
|
15
|
+
def emulate_out(func):
    """
    Adds an out argument to write output of ``func`` to.

    JAX arrays are immutable, so ``out`` cannot be modified in place. When
    ``out`` is given, the result of ``func`` is scattered into a copy of
    ``out`` via ``out.at[...].set`` and that copy is returned; the array
    passed as ``out`` is left untouched. Callers must use the return value.

    Parameters
    ----------
    func : callable
        Function whose return value may be redirected into ``out``.

    Returns
    -------
    callable
        Wrapper accepting the same arguments as ``func`` plus an optional
        keyword-only ``out`` array.
    """

    @wraps(func)
    def inner(*args, out=None, **kwargs):
        ret = func(*args, **kwargs)
        if out is None:
            return ret
        # Ellipsis (rather than a ``[:]`` slice) also supports 0-d ``out``.
        return out.at[...].set(ret)

    return inner
|
29
|
+
|
30
|
+
|
31
|
+
class JaxBackend(NumpyFFTWBackend):
    """
    A jax-based matching backend.

    JAX arrays are immutable, so methods that accept an ``out`` argument do
    not write into it in place; they return a new array instead. Callers
    must always use the returned value.
    """

    def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
        import jax.scipy as jsp
        import jax.numpy as jnp

        float_dtype = jnp.float32 if float_dtype is None else float_dtype
        complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
        int_dtype = jnp.int32 if int_dtype is None else int_dtype

        # NOTE: **kwargs are accepted for interface compatibility but are not
        # forwarded to the parent constructor.
        super().__init__(
            array_backend=jnp,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=float_dtype,
        )
        self.scipy = jsp
        self._create_ufuncs()
        try:
            from ._jax_utils import scan as _

            self.scan = self._scan
        except Exception:
            # Deliberate best effort: if the jax-specific scan cannot be
            # imported, self.scan is simply not overridden here.
            pass

    def from_sharedarr(self, arr: BackendArray) -> BackendArray:
        """Return ``arr`` unchanged; no shared-memory conversion is done."""
        return arr

    @staticmethod
    def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged; ``shared_memory_handler`` is ignored."""
        return arr

    @staticmethod
    def at(arr, idx, value) -> BackendArray:
        """Return a copy of ``arr`` with ``arr[idx]`` set to ``value``."""
        arr = arr.at[idx].set(value)
        return arr

    def topleft_pad(
        self, arr: BackendArray, shape: Tuple[int], padval: int = 0
    ) -> BackendArray:
        """
        Place ``arr`` in the top-left corner of a new array of ``shape``.

        The new array is filled with ``padval``. Axes along which ``arr`` is
        larger than ``shape`` are cropped to ``shape``.
        """
        b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
        aind = [slice(None, None)] * arr.ndim
        bind = [slice(None, None)] * arr.ndim
        for i in range(arr.ndim):
            if arr.shape[i] > shape[i]:
                aind[i] = slice(0, shape[i])
            elif arr.shape[i] < shape[i]:
                bind[i] = slice(0, arr.shape[i])
        b = b.at[tuple(bind)].set(arr[tuple(aind)])
        return b

    def _create_ufuncs(self):
        """
        Bind selected ``jax.numpy`` functions as instance attributes.

        Element-wise operations are wrapped with ``emulate_out`` so they
        accept an ``out`` argument; array constructors are bound unchanged.
        """
        ufuncs = [
            "add",
            "subtract",
            "multiply",
            "divide",
            "square",
            "sqrt",
            "maximum",
            "exp",
        ]
        # Functions stored on the instance (not the class) are never bound to
        # ``self``, so no staticmethod wrapper is needed. staticmethod objects
        # are only directly callable from Python 3.10 onwards.
        for ufunc in ufuncs:
            backend_method = emulate_out(getattr(self._array_backend, ufunc))
            setattr(self, ufunc, backend_method)

        for ufunc in ("zeros", "full"):
            setattr(self, ufunc, getattr(self._array_backend, ufunc))

    def fill(self, arr: BackendArray, value: float) -> BackendArray:
        """Return a new array shaped like ``arr`` filled with ``value``."""
        return self._array_backend.full(
            shape=arr.shape, dtype=arr.dtype, fill_value=value
        )

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int] = None,
        inv_output_shape: Tuple[int] = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Create forward and inverse real FFT callables with fixed shapes.

        ``inv_shape`` and ``**kwargs`` are accepted for interface
        compatibility but are not used. The returned callables accept an
        ``out`` argument which they ignore; they always return a new array.
        """
        rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)

        def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
            return self._array_backend.fft.rfftn(arr, s=s, axes=axes)

        def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
            return self._array_backend.fft.irfftn(arr, s=s, axes=axes)

        return rfftn, irfftn

    def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
        """Forward real FFT; positional extras are ignored, kwargs forwarded."""
        return self._array_backend.fft.rfftn(arr, **kwargs)

    def irfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
        """Inverse real FFT; positional extras are ignored, kwargs forwarded."""
        return self._array_backend.fft.irfftn(arr, **kwargs)

    def rigid_transform(
        self,
        arr: BackendArray,
        rotation_matrix: BackendArray,
        out: BackendArray = None,
        out_mask: BackendArray = None,
        translation: BackendArray = None,
        arr_mask: BackendArray = None,
        order: int = 1,
        **kwargs,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Rotate ``arr`` (and optionally ``arr_mask``) about the array center.

        Parameters
        ----------
        arr : BackendArray
            Array to transform.
        rotation_matrix : BackendArray
            Square rotation matrix. If it has one dimension fewer than
            ``arr``, it is embedded so that axis 0 is left untransformed.
        out, out_mask : BackendArray, optional
            Ignored; new arrays are always returned.
        translation : BackendArray, optional
            Added to the sampling coordinates. A 1-D vector of length
            ``arr.ndim`` is applied to every voxel.
        arr_mask : BackendArray, optional
            Mask transformed with the same coordinates as ``arr``.
        order : int, optional
            Interpolation order passed to ``map_coordinates``.

        Returns
        -------
        tuple
            Transformed array and transformed mask (``arr_mask`` itself,
            i.e. possibly None, when no mask was given).
        """
        rotate_mask = arr_mask is not None

        # Embedding the matrix leaves axis 0 at integer sample positions,
        # which is only exact for interpolation order <= 1.
        if arr.ndim != rotation_matrix.shape[0]:
            matrix = self._array_backend.zeros((arr.ndim, arr.ndim))
            matrix = matrix.at[0, 0].set(1)
            matrix = matrix.at[1:, 1:].add(rotation_matrix)
            rotation_matrix = matrix

        center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
        indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
        indices = indices.reshape((arr.ndim, -1))
        indices = indices.at[:].add(-center)
        indices = self._array_backend.matmul(rotation_matrix.T, indices)
        indices = indices.at[:].add(center)
        if translation is not None:
            translation = self._array_backend.asarray(translation, dtype=indices.dtype)
            if translation.ndim == 1:
                # (ndim,) -> (ndim, 1) so it broadcasts across all N voxels
                # of the (ndim, N) coordinate array, like ``center`` above.
                translation = translation[:, None]
            indices = indices.at[:].add(translation)

        out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
            arr.shape
        )

        out_mask = arr_mask
        if rotate_mask:
            out_mask = self.scipy.ndimage.map_coordinates(
                arr_mask, indices, order=order
            ).reshape(arr_mask.shape)

        return out, out_mask

    def max_score_over_rotations(
        self,
        scores: BackendArray,
        max_scores: BackendArray,
        rotations: BackendArray,
        rotation_index: int,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Element-wise update of the running maximum score and its rotation.

        Wherever ``max_scores`` is not strictly greater than ``scores``, the
        new score and ``rotation_index`` are taken (ties pick the new
        rotation). Returns new ``(max_scores, rotations)`` arrays.
        """
        update = self.greater(max_scores, scores)
        max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
        rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
        return max_scores, rotations

    def _scan(
        self,
        matching_data: type,
        splits: Tuple[Tuple[slice, slice]],
        n_jobs: int,
        callback_class,
        rotate_mask: bool = False,
        **kwargs,
    ) -> List:
        """
        Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
        :py:class:`tme.analyzer.MaxScoreOverRotations`.

        Splits are processed in batches of ``n_jobs``. The target splits of a
        batch are padded to a common shape, stacked, and scored in a single
        call to the jax ``scan`` helper.

        Returns
        -------
        list
            One tuple per split, produced by ``callback_class._postprocess``.
        """
        from ._jax_utils import scan as scan_inner

        pad_target = len(splits) > 1
        convolution_mode = "valid" if pad_target else "same"
        target_pad = matching_data.target_padding(pad_target=pad_target)

        # Derived from the first split only; presumably all splits share its
        # extent - TODO confirm against the splitting logic.
        target_shape = tuple(
            (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
        )
        conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
            target_shape=self.to_numpy_array(target_shape),
            template_shape=self.to_numpy_array(matching_data._template.shape),
            pad_fourier=False,
        )

        analyzer_args = {
            "convolution_mode": convolution_mode,
            "fourier_shift": shift,
            "targetshape": target_shape,
            "templateshape": matching_data.template.shape,
            "convolution_shape": conv_shape,
        }

        create_target_filter = matching_data.target_filter is not None
        create_template_filter = matching_data.template_filter is not None
        create_filter = create_target_filter or create_template_filter

        # The template filter is evaluated on the template's own shape.
        fastt_shape = matching_data._template.shape

        ret, template_filter, target_filter = [], 1, 1
        rotation_mapping = {
            self.tobytes(matching_data.rotations[i]): i
            for i in range(matching_data.rotations.shape[0])
        }
        for split_start in range(0, len(splits), n_jobs):
            split_subset = splits[split_start : (split_start + n_jobs)]
            if not len(split_subset):
                continue

            targets, translation_offsets = [], []
            for target_split, template_split in split_subset:
                base = matching_data.subset_by_slice(
                    target_slice=target_split,
                    target_pad=target_pad,
                    template_slice=template_split,
                )
                translation_offsets.append(base._translation_offset)
                targets.append(self.topleft_pad(base._target, fast_shape))

            # Filters are built once, from the first target of the first
            # batch, and reused for every later batch.
            if create_filter:
                filter_args = {
                    "data_rfft": self.fft.rfftn(targets[0]),
                    "return_real_fourier": True,
                    "shape_is_real_fourier": False,
                }

                # Element (0, ..., 0) is zeroed; presumably the zero-frequency
                # term of an unshifted spectrum.
                if create_template_filter:
                    template_filter = matching_data.template_filter(
                        shape=fastt_shape, **filter_args
                    )["data"]
                    template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)

                if create_target_filter:
                    target_filter = matching_data.target_filter(
                        shape=fast_shape, **filter_args
                    )["data"]
                    target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)

                create_filter, create_template_filter, create_target_filter = (False,) * 3
            base, targets = None, self._array_backend.stack(targets)
            scores, rotations = scan_inner(
                self.astype(targets, self._float_dtype),
                matching_data.template,
                matching_data.template_mask,
                matching_data.rotations,
                template_filter,
                target_filter,
                fast_shape,
                rotate_mask,
            )

            for index in range(scores.shape[0]):
                temp = callback_class(
                    shape=scores.shape,
                    scores=scores[index],
                    rotations=rotations[index],
                    thread_safe=False,
                    offset=translation_offsets[index],
                )
                temp.rotation_mapping = rotation_mapping
                ret.append(tuple(temp._postprocess(**analyzer_args)))

        return ret

    def get_available_memory(self) -> int:
        """
        Return available device memory in bytes.

        The ``bytes_limit`` values of all non-CPU jax devices are summed; if
        that total is positive it is returned. Otherwise the host memory
        reported by the parent class is returned, which is only queried when
        a CPU device is listed (0 otherwise).
        """
        import jax

        _memory = {"cpu": 0, "gpu": 0}
        for device in jax.devices():
            if device.platform == "cpu":
                _memory["cpu"] = super().get_available_memory()
            else:
                # memory_stats() returns None on backends that do not report
                # memory usage; treat that as "no limit information".
                mem_stats = device.memory_stats() or {}
                _memory["gpu"] += mem_stats.get("bytes_limit", 0)

        if _memory["gpu"] > 0:
            return _memory["gpu"]
        return _memory["cpu"]
|