pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +147 -93
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +67 -26
- scripts/preprocessor_gui.py +175 -85
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +451 -809
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +111 -223
- tme/backends/jax_backend.py +214 -150
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +239 -507
- tme/backends/pytorch_backend.py +21 -145
- tme/density.py +233 -363
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +322 -285
- tme/matching_exhaustive.py +172 -1493
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +280 -386
- tme/memory.py +377 -0
- tme/orientations.py +52 -12
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +34 -40
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/memory.py
ADDED
@@ -0,0 +1,377 @@
|
|
1
|
+
""" Compute memory consumption of template matching components.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from abc import ABC, abstractmethod
|
9
|
+
from typing import Tuple
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
from pyfftw import next_fast_len
|
13
|
+
|
14
|
+
|
15
|
+
class MatchingMemoryUsage(ABC):
    """
    Abstract interface for estimating the memory footprint of template matching.

    Each subclass models one component of a matching run and reports its
    requirements as a base cost (:py:meth:`base_usage`) plus a cost that is
    incurred once per fork (:py:meth:`per_fork`).

    Parameters
    ----------
    fast_shape : tuple of int
        Shape of the real array.
    ft_shape : tuple of int
        Shape of the complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Attributes
    ----------
    real_array_size : int
        Number of elements in real array, i.e. the product of ``fast_shape``.
    complex_array_size : int
        Number of elements in complex array, i.e. the product of ``ft_shape``.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Methods
    -------
    base_usage():
        Returns the base memory usage in bytes.
    per_fork():
        Returns the memory usage in bytes per fork.
    """

    def __init__(
        self,
        fast_shape: Tuple[int],
        ft_shape: Tuple[int],
        float_nbytes: int,
        complex_nbytes: int,
        integer_nbytes: int,
    ):
        # Element counts of the real-space and Fourier-space arrays.
        self.real_array_size, self.complex_array_size = (
            np.prod(fast_shape),
            np.prod(ft_shape),
        )
        # Item sizes in bytes of the dtypes used during matching.
        self.float_nbytes, self.complex_nbytes, self.integer_nbytes = (
            float_nbytes,
            complex_nbytes,
            integer_nbytes,
        )

    @abstractmethod
    def base_usage(self) -> int:
        """Return the base memory usage in bytes."""

    @abstractmethod
    def per_fork(self) -> int:
        """Return the memory usage per fork in bytes."""
|
74
|
+
|
75
|
+
|
76
|
+
class CCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CC scoring.

    The base usage and the usage of every fork each account for one
    real-valued array of ``fast_shape`` and one complex-valued array of
    ``ft_shape``.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.cc_setup`.
    """

    def base_usage(self) -> int:
        real_bytes = self.float_nbytes * self.real_array_size
        fourier_bytes = self.complex_nbytes * self.complex_array_size
        return real_bytes + fourier_bytes

    def per_fork(self) -> int:
        # Same pair of arrays as the base usage, allocated once per fork.
        real_bytes = self.float_nbytes * self.real_array_size
        fourier_bytes = self.complex_nbytes * self.complex_array_size
        return real_bytes + fourier_bytes
|
94
|
+
|
95
|
+
|
96
|
+
class LCCMemoryUsage(CCMemoryUsage):
    """
    Memory usage estimation for LCC scoring.

    Inherits the base and per-fork estimates of :py:class:`CCMemoryUsage`
    unchanged.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.lcc_setup`.
    """
|
103
|
+
|
104
|
+
|
105
|
+
class CORRMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CORR scoring.

    The base usage accounts for four real-valued arrays of ``fast_shape`` and
    one complex-valued array of ``ft_shape``; each fork adds one of each.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.corr_setup`.
    """

    def base_usage(self) -> int:
        return (
            4 * self.real_array_size * self.float_nbytes
            + self.complex_array_size * self.complex_nbytes
        )

    def per_fork(self) -> int:
        return (
            self.real_array_size * self.float_nbytes
            + self.complex_array_size * self.complex_nbytes
        )
|
123
|
+
|
124
|
+
|
125
|
+
class CAMMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for CAM scoring.

    Inherits the base and per-fork estimates of :py:class:`CORRMemoryUsage`
    unchanged.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.cam_setup`.
    """
|
133
|
+
|
134
|
+
|
135
|
+
class FLCSphericalMaskMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for FLCSphericalMask scoring.

    Inherits the base and per-fork estimates of :py:class:`CORRMemoryUsage`
    unchanged.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup`.
    """
|
143
|
+
|
144
|
+
|
145
|
+
class FLCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for FLC scoring.

    The base usage accounts for two real-valued arrays of ``fast_shape`` and
    two complex-valued arrays of ``ft_shape``; each fork adds three real-valued
    and two complex-valued arrays.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_setup`.
    """

    def base_usage(self) -> int:
        n_real, n_fourier = 2, 2
        return (
            n_real * self.real_array_size * self.float_nbytes
            + n_fourier * self.complex_array_size * self.complex_nbytes
        )

    def per_fork(self) -> int:
        n_real, n_fourier = 3, 2
        return (
            n_real * self.real_array_size * self.float_nbytes
            + n_fourier * self.complex_array_size * self.complex_nbytes
        )
|
163
|
+
|
164
|
+
|
165
|
+
class MCCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for MCC scoring.

    The base usage accounts for two real-valued arrays of ``fast_shape`` and
    three complex-valued arrays of ``ft_shape``; each fork adds six real-valued
    arrays and one complex-valued array.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.mcc_setup`.
    """

    def base_usage(self) -> int:
        n_real, n_fourier = 2, 3
        return (
            n_real * self.real_array_size * self.float_nbytes
            + n_fourier * self.complex_array_size * self.complex_nbytes
        )

    def per_fork(self) -> int:
        n_real = 6
        return (
            n_real * self.real_array_size * self.float_nbytes
            + self.complex_array_size * self.complex_nbytes
        )
|
183
|
+
|
184
|
+
|
185
|
+
class MaxScoreOverRotationsMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the MaxScoreOverRotations analyzer.

    Only a base cost of two real-valued arrays of ``fast_shape`` is accounted
    for; forks add nothing.

    See Also
    --------
    :py:class:`tme.analyzer.MaxScoreOverRotations`.
    """

    def base_usage(self) -> int:
        return 2 * self.real_array_size * self.float_nbytes

    def per_fork(self) -> int:
        # No per-fork allocations are accounted for this analyzer.
        return 0
|
200
|
+
|
201
|
+
|
202
|
+
class PeakCallerMaximumFilterMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the PeakCallerMaximumFilter analyzer.

    One real-valued array of ``fast_shape`` is accounted for in the base usage
    and one more for every fork.

    See Also
    --------
    :py:class:`tme.analyzer.PeakCallerMaximumFilter`.
    """

    def base_usage(self) -> int:
        return self.real_array_size * self.float_nbytes

    def per_fork(self) -> int:
        return self.real_array_size * self.float_nbytes
|
218
|
+
|
219
|
+
|
220
|
+
class CupyBackendMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CupyBackend.

    Only a base cost is accounted for; forks add nothing.

    See Also
    --------
    :py:class:`tme.backends.CupyBackend`.
    """

    def base_usage(self) -> int:
        # FFT plans, overhead from assigning FFT result, rotation interpolation.
        # NOTE(review): the real-space element count is paired with the complex
        # item size and the Fourier-space count with the float item size.
        # Possibly a deliberate over-estimate -- confirm.
        return (
            3 * self.real_array_size * self.complex_nbytes
            + 2 * self.complex_array_size * self.float_nbytes
        )

    def per_fork(self) -> int:
        return 0
|
237
|
+
|
238
|
+
|
239
|
+
def _compute_convolution_shapes(
|
240
|
+
arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
241
|
+
) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
|
242
|
+
"""
|
243
|
+
Computes regular, optimized and fourier convolution shape.
|
244
|
+
|
245
|
+
Parameters
|
246
|
+
----------
|
247
|
+
arr1_shape : tuple
|
248
|
+
Tuple of integers corresponding to array1 shape.
|
249
|
+
arr2_shape : tuple
|
250
|
+
Tuple of integers corresponding to array2 shape.
|
251
|
+
|
252
|
+
Returns
|
253
|
+
-------
|
254
|
+
tuple
|
255
|
+
Tuple with regular convolution shape, convolution shape optimized for faster
|
256
|
+
fourier transform, shape of the forward fourier transform
|
257
|
+
(see :py:meth:`build_fft`).
|
258
|
+
"""
|
259
|
+
convolution_shape = np.add(arr1_shape, arr2_shape) - 1
|
260
|
+
fast_shape = [next_fast_len(x) for x in convolution_shape]
|
261
|
+
fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
|
262
|
+
|
263
|
+
return convolution_shape, fast_shape, fast_ft_shape
|
264
|
+
|
265
|
+
|
266
|
+
# Maps scoring method, analyzer and backend names to their memory estimators.
# Key order matters: estimate_ram_usage joins these keys, in this order, into
# its error message.
MATCHING_MEMORY_REGISTRY = dict(
    CC=CCMemoryUsage,
    LCC=LCCMemoryUsage,
    CORR=CORRMemoryUsage,
    CAM=CAMMemoryUsage,
    MCC=MCCMemoryUsage,
    FLCSphericalMask=FLCSphericalMaskMemoryUsage,
    FLC=FLCMemoryUsage,
    MaxScoreOverRotations=MaxScoreOverRotationsMemoryUsage,
    PeakCallerMaximumFilter=PeakCallerMaximumFilterMemoryUsage,
    cupy=CupyBackendMemoryUsage,
    pytorch=CupyBackendMemoryUsage,
)
|
279
|
+
|
280
|
+
|
281
|
+
def _component_ram_usage(name: str, ncores: int, **estimator_kwargs) -> int:
    """
    Return the RAM usage in bytes of the estimator registered under ``name``.

    The usage is the estimator's base usage plus its per-fork usage multiplied
    by ``ncores``. Names without a registered estimator, including None,
    contribute 0.
    """
    estimator = MATCHING_MEMORY_REGISTRY.get(name, None)
    if estimator is None:
        return 0
    instance = estimator(**estimator_kwargs)
    return instance.base_usage() + instance.per_fork() * ncores


def estimate_ram_usage(
    shape1: Tuple[int],
    shape2: Tuple[int],
    matching_method: str,
    ncores: int,
    analyzer_method: str = None,
    backend: str = None,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> int:
    """
    Estimate the RAM usage for a given convolution operation based on input shapes,
    matching_method, and number of cores.

    The estimate sums the footprint of the matching method and, when they have
    a registered estimator, those of the analyzer and the backend. Each
    component contributes its base usage plus its per-fork usage times
    ``ncores``.

    Parameters
    ----------
    shape1 : tuple
        The shape of the input target.
    shape2 : tuple
        The shape of the input template.
    matching_method : str
        The method used for the operation. Must be a key of
        ``MATCHING_MEMORY_REGISTRY``.
    ncores : int
        The number of CPU cores used for the operation.
    analyzer_method : str, optional
        The method used for score analysis. Names without a registered
        estimator, including None, contribute nothing.
    backend : str, optional
        Backend used for computation. Names without a registered estimator,
        including None, contribute nothing.
    float_nbytes : int, optional
        Number of bytes of the used float, e.g. 4 for float32. Defaults to 4.
    complex_nbytes : int, optional
        Number of bytes of the used complex, e.g. 8 for complex64. Defaults to 8.
    integer_nbytes : int, optional
        Number of bytes of the used integer, e.g. 4 for int32. Defaults to 4.

    Returns
    -------
    int
        The estimated RAM usage for the operation in bytes.

    Raises
    ------
    ValueError
        If an unsupported matching_method is provided.

    Notes
    -----
    Residual memory from other objects that may remain allocated during
    template matching, e.g. the full sized target when using splitting,
    is not considered by this function.
    """
    if matching_method not in MATCHING_MEMORY_REGISTRY:
        supported = ",".join(MATCHING_MEMORY_REGISTRY.keys())
        raise ValueError(
            f"Unsupported matching_method {matching_method!r}. "
            f"Supported options are {supported}"
        )

    # Only the FFT-optimized and Fourier shapes determine array sizes.
    _, fast_shape, ft_shape = _compute_convolution_shapes(shape1, shape2)
    estimator_kwargs = {
        "fast_shape": fast_shape,
        "ft_shape": ft_shape,
        "float_nbytes": float_nbytes,
        "complex_nbytes": complex_nbytes,
        "integer_nbytes": integer_nbytes,
    }

    nbytes = sum(
        _component_ram_usage(name, ncores, **estimator_kwargs)
        for name in (matching_method, analyzer_method, backend)
    )
    # The estimators compute on numpy scalars; return a plain Python int.
    return int(nbytes)
|
tme/orientations.py
CHANGED
@@ -62,16 +62,16 @@ class Orientations:
|
|
62
62
|
Array with additional orientation details (n, ).
|
63
63
|
"""
|
64
64
|
|
65
|
-
#:
|
65
|
+
#: Array with translations of each orientation (n, d).
|
66
66
|
translations: np.ndarray
|
67
67
|
|
68
|
-
#:
|
68
|
+
#: Array with zyx euler angles of each orientation (n, d).
|
69
69
|
rotations: np.ndarray
|
70
70
|
|
71
|
-
#:
|
71
|
+
#: Array with scores of each orientation (n, ).
|
72
72
|
scores: np.ndarray
|
73
73
|
|
74
|
-
#:
|
74
|
+
#: Array with additional details of each orientation(n, ).
|
75
75
|
details: np.ndarray
|
76
76
|
|
77
77
|
def __post_init__(self):
|
@@ -130,9 +130,21 @@ class Orientations:
|
|
130
130
|
"scores",
|
131
131
|
"details",
|
132
132
|
)
|
133
|
-
kwargs = {attr: getattr(self, attr)[indices] for attr in attributes}
|
133
|
+
kwargs = {attr: getattr(self, attr)[indices].copy() for attr in attributes}
|
134
134
|
return self.__class__(**kwargs)
|
135
135
|
|
136
|
+
def copy(self) -> "Orientations":
    """
    Create a copy of the current class instance.

    The copy is produced by selecting every orientation, in order, through
    ``__getitem__``.

    Returns
    -------
    :py:class:`Orientations`
        Copy of the class instance.
    """
    return self[np.arange(self.scores.size)]
|
147
|
+
|
136
148
|
def to_file(self, filename: str, file_format: type = None, **kwargs) -> None:
|
137
149
|
"""
|
138
150
|
Save the current class instance to a file in the specified format.
|
@@ -146,7 +158,7 @@ class Orientations:
|
|
146
158
|
the file_format from the typical extension. Supported formats are
|
147
159
|
|
148
160
|
+---------------+----------------------------------------------------+
|
149
|
-
| text |
|
161
|
+
| text | pytme's standard tab-separated orientations file |
|
150
162
|
+---------------+----------------------------------------------------+
|
151
163
|
| relion | Creates a STAR file of orientations |
|
152
164
|
+---------------+----------------------------------------------------+
|
@@ -207,11 +219,11 @@ class Orientations:
|
|
207
219
|
with open(filename, mode="w", encoding="utf-8") as ofile:
|
208
220
|
_ = ofile.write(f"{header}\n")
|
209
221
|
for translation, angles, score, detail in self:
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
f"{translation_string}\t{angle_string}\t{score}\t{detail}\n"
|
222
|
+
out_string = (
|
223
|
+
"\t".join([str(x) for x in (*translation, *angles, score, detail)])
|
224
|
+
+ "\n"
|
214
225
|
)
|
226
|
+
_ = ofile.write(out_string)
|
215
227
|
return None
|
216
228
|
|
217
229
|
def _to_dynamo_tbl(
|
@@ -465,8 +477,10 @@ class Orientations:
|
|
465
477
|
|
466
478
|
Notes
|
467
479
|
-----
|
468
|
-
The text file is expected to have a header and data in columns
|
469
|
-
|
480
|
+
The text file is expected to have a header and data in columns. Colums containing
|
481
|
+
the name euler are considered to specify rotations. The second last and last
|
482
|
+
column correspond to score and detail. Its possible to only specify translations,
|
483
|
+
in this case the remaining columns will be filled with trivial values.
|
470
484
|
"""
|
471
485
|
with open(filename, mode="r", encoding="utf-8") as infile:
|
472
486
|
data = [x.strip().split("\t") for x in infile.read().split("\n")]
|
@@ -493,6 +507,32 @@ class Orientations:
|
|
493
507
|
score = np.array(score)
|
494
508
|
detail = np.array(detail)
|
495
509
|
|
510
|
+
if translation.shape[1] == len(header):
|
511
|
+
rotation = np.zeros(translation.shape, dtype=np.float32)
|
512
|
+
score = np.zeros(translation.shape[0], dtype=np.float32)
|
513
|
+
detail = np.zeros(translation.shape[0], dtype=np.float32) - 1
|
514
|
+
|
515
|
+
if rotation.size == 0 and translation.shape[0] != 0:
|
516
|
+
rotation = np.zeros(translation.shape, dtype=np.float32)
|
517
|
+
|
518
|
+
header_order = tuple(x for x in header if x in ascii_lowercase)
|
519
|
+
header_order = zip(header_order, range(len(header_order)))
|
520
|
+
sort_order = tuple(
|
521
|
+
x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=True)
|
522
|
+
)
|
523
|
+
translation = translation[..., sort_order]
|
524
|
+
|
525
|
+
header_order = tuple(
|
526
|
+
x
|
527
|
+
for x in header
|
528
|
+
if "euler" in x and x.replace("euler_", "") in ascii_lowercase
|
529
|
+
)
|
530
|
+
header_order = zip(header_order, range(len(header_order)))
|
531
|
+
sort_order = tuple(
|
532
|
+
x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=True)
|
533
|
+
)
|
534
|
+
rotation = rotation[..., sort_order]
|
535
|
+
|
496
536
|
return translation, rotation, score, detail
|
497
537
|
|
498
538
|
@staticmethod
|
tme/parser.py
CHANGED
@@ -137,8 +137,7 @@ class Parser(ABC):
|
|
137
137
|
|
138
138
|
class PDBParser(Parser):
|
139
139
|
"""
|
140
|
-
|
141
|
-
This class is specifically designed to work with PDB file format.
|
140
|
+
Convert PDB file data into a dictionary representation [1]_.
|
142
141
|
|
143
142
|
References
|
144
143
|
----------
|
@@ -228,8 +227,8 @@ class PDBParser(Parser):
|
|
228
227
|
|
229
228
|
class MMCIFParser(Parser):
|
230
229
|
"""
|
231
|
-
|
232
|
-
|
230
|
+
Convert MMCIF file data into a dictionary representation. This implementation
|
231
|
+
heavily relies on the atomium library [1]_.
|
233
232
|
|
234
233
|
References
|
235
234
|
----------
|
tme/preprocessing/_utils.py
CHANGED
@@ -5,12 +5,13 @@
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
7
|
|
8
|
-
from typing import Tuple
|
8
|
+
from typing import Tuple, List
|
9
9
|
|
10
10
|
import numpy as np
|
11
|
-
from numpy.typing import NDArray
|
12
11
|
|
13
|
-
from ..backends import backend
|
12
|
+
from ..backends import backend as be
|
13
|
+
from ..backends import NumpyFFTWBackend
|
14
|
+
from ..types import BackendArray, NDArray
|
14
15
|
from ..matching_utils import euler_to_rotationmatrix
|
15
16
|
|
16
17
|
|
@@ -93,18 +94,27 @@ def frequency_grid_at_angle(
|
|
93
94
|
tilt_shape = compute_tilt_shape(
|
94
95
|
shape=shape, opening_axis=opening_axis, reduce_dim=False
|
95
96
|
)
|
96
|
-
|
97
|
+
|
98
|
+
if angle == 0:
|
99
|
+
index_grid = fftfreqn(
|
100
|
+
tuple(x for x in tilt_shape if x != 1),
|
101
|
+
sampling_rate=1,
|
102
|
+
compute_euclidean_norm=True,
|
103
|
+
)
|
104
|
+
|
97
105
|
if angle != 0:
|
98
106
|
angles = np.zeros(len(shape))
|
99
107
|
angles[tilt_axis] = angle
|
100
108
|
rotation_matrix = euler_to_rotationmatrix(np.roll(angles, opening_axis - 1))
|
109
|
+
|
110
|
+
index_grid = fftfreqn(tilt_shape, sampling_rate=None)
|
101
111
|
index_grid = np.einsum("ij,j...->i...", rotation_matrix, index_grid)
|
112
|
+
norm = np.multiply(sampling_rate, shape).astype(int)
|
102
113
|
|
103
|
-
|
114
|
+
index_grid = np.divide(index_grid.T, norm).T
|
115
|
+
index_grid = np.squeeze(index_grid)
|
116
|
+
index_grid = np.linalg.norm(index_grid, axis=(0))
|
104
117
|
|
105
|
-
index_grid = np.multiply(index_grid.T, norm).T
|
106
|
-
index_grid = np.squeeze(index_grid)
|
107
|
-
index_grid = np.linalg.norm(index_grid, axis=(0))
|
108
118
|
return index_grid
|
109
119
|
|
110
120
|
|
@@ -113,9 +123,10 @@ def fftfreqn(
|
|
113
123
|
sampling_rate: Tuple[float],
|
114
124
|
compute_euclidean_norm: bool = False,
|
115
125
|
shape_is_real_fourier: bool = False,
|
126
|
+
return_sparse_grid: bool = False,
|
116
127
|
) -> NDArray:
|
117
128
|
"""
|
118
|
-
Generate the n-dimensional discrete Fourier
|
129
|
+
Generate the n-dimensional discrete Fourier transform sample frequencies.
|
119
130
|
|
120
131
|
Parameters:
|
121
132
|
-----------
|
@@ -133,56 +144,74 @@ def fftfreqn(
|
|
133
144
|
NDArray
|
134
145
|
The sample frequencies.
|
135
146
|
"""
|
136
|
-
|
137
|
-
|
138
|
-
norm =
|
147
|
+
# There is no real need to have these operations on GPU right now
|
148
|
+
temp_backend = NumpyFFTWBackend()
|
149
|
+
norm = temp_backend.full(len(shape), fill_value=1)
|
150
|
+
center = temp_backend.astype(temp_backend.divide(shape, 2), temp_backend._int_dtype)
|
139
151
|
if sampling_rate is not None:
|
140
|
-
norm =
|
152
|
+
norm = temp_backend.astype(temp_backend.multiply(shape, sampling_rate), int)
|
141
153
|
|
142
154
|
if shape_is_real_fourier:
|
143
|
-
center[-1] = 0
|
144
|
-
norm[-1] = 1
|
155
|
+
center[-1], norm[-1] = 0, 1
|
145
156
|
if sampling_rate is not None:
|
146
157
|
norm[-1] = (shape[-1] - 1) * 2 * sampling_rate
|
147
158
|
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
159
|
+
grids = []
|
160
|
+
for i, x in enumerate(shape):
|
161
|
+
baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
|
162
|
+
grid = (temp_backend.arange(x) - center[i]) / norm[i]
|
163
|
+
grids.append(temp_backend.reshape(grid, baseline_dims))
|
152
164
|
|
153
165
|
if compute_euclidean_norm:
|
154
|
-
|
155
|
-
|
156
|
-
|
166
|
+
grids = sum(temp_backend.square(x) for x in grids)
|
167
|
+
grids = temp_backend.sqrt(grids, out=grids)
|
168
|
+
return grids
|
169
|
+
|
170
|
+
if return_sparse_grid:
|
171
|
+
return grids
|
157
172
|
|
158
|
-
|
173
|
+
grid_flesh = temp_backend.full(shape, fill_value=1)
|
174
|
+
grids = temp_backend.stack(tuple(grid * grid_flesh for grid in grids))
|
159
175
|
|
176
|
+
return grids
|
160
177
|
|
161
|
-
|
178
|
+
|
179
|
+
def crop_real_fourier(data: BackendArray) -> BackendArray:
|
162
180
|
"""
|
163
181
|
Crop the real part of a Fourier transform.
|
164
182
|
|
165
183
|
Parameters:
|
166
184
|
-----------
|
167
|
-
data :
|
185
|
+
data : BackendArray
|
168
186
|
The Fourier transformed data.
|
169
187
|
|
170
188
|
Returns:
|
171
189
|
--------
|
172
|
-
|
190
|
+
BackendArray
|
173
191
|
The cropped data.
|
174
192
|
"""
|
175
193
|
stop = 1 + (data.shape[-1] // 2)
|
176
194
|
return data[..., :stop]
|
177
195
|
|
178
196
|
|
179
|
-
def
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
197
|
+
def compute_fourier_shape(
    shape: Tuple[int], shape_is_real_fourier: bool = False
) -> List[int]:
    """
    Compute the shape of the real-input Fourier transform of an array.

    Parameters:
    -----------
    shape : tuple of int
        Real-space shape of the array.
    shape_is_real_fourier : bool, optional
        Whether ``shape`` already is a real Fourier shape. If True, ``shape``
        is returned as-is without conversion.

    Returns:
    --------
    list of int
        ``shape`` as integers with the last axis reduced to ``n // 2 + 1``
        elements, or the unchanged input if ``shape_is_real_fourier`` is True.
    """
    if not shape_is_real_fourier:
        fourier_shape = list(map(int, shape))
        fourier_shape[-1] = fourier_shape[-1] // 2 + 1
        return fourier_shape
    return shape
|
205
|
+
|
206
|
+
|
207
|
+
def shift_fourier(
|
208
|
+
data: BackendArray, shape_is_real_fourier: bool = False
|
209
|
+
) -> BackendArray:
|
210
|
+
shape = be.to_backend_array(data.shape)
|
211
|
+
shift = be.add(be.divide(shape, 2), be.mod(shape, 2))
|
212
|
+
shift = [int(x) for x in shift]
|
184
213
|
if shape_is_real_fourier:
|
185
214
|
shift[-1] = 0
|
186
215
|
|
187
|
-
data =
|
216
|
+
data = be.roll(data, shift, tuple(i for i in range(len(shift))))
|
188
217
|
return data
|
tme/preprocessing/compose.py
CHANGED
@@ -7,7 +7,7 @@
|
|
7
7
|
|
8
8
|
from typing import Tuple, Dict
|
9
9
|
|
10
|
-
from tme.backends import backend
|
10
|
+
from tme.backends import backend as be
|
11
11
|
|
12
12
|
|
13
13
|
class Compose:
|
@@ -42,9 +42,13 @@ class Compose:
|
|
42
42
|
kwargs.update(meta)
|
43
43
|
ret = transform(**kwargs)
|
44
44
|
|
45
|
+
if "data" not in ret:
|
46
|
+
continue
|
47
|
+
|
45
48
|
if ret.get("is_multiplicative_filter", False):
|
46
|
-
|
47
|
-
ret["
|
49
|
+
prev_data = meta.pop("data")
|
50
|
+
ret["data"] = be.multiply(ret["data"], prev_data, out=ret["data"])
|
51
|
+
ret["merge"], prev_data = None, None
|
48
52
|
|
49
53
|
meta = ret
|
50
54
|
|