pytme-0.1.5-cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
- pytme-0.1.5.data/scripts/match_template.py +744 -0
- pytme-0.1.5.data/scripts/postprocess.py +279 -0
- pytme-0.1.5.data/scripts/preprocess.py +93 -0
- pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
- pytme-0.1.5.dist-info/LICENSE +153 -0
- pytme-0.1.5.dist-info/METADATA +69 -0
- pytme-0.1.5.dist-info/RECORD +63 -0
- pytme-0.1.5.dist-info/WHEEL +5 -0
- pytme-0.1.5.dist-info/entry_points.txt +6 -0
- pytme-0.1.5.dist-info/top_level.txt +2 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +81 -0
- scripts/match_template.py +744 -0
- scripts/match_template_devel.py +788 -0
- scripts/postprocess.py +279 -0
- scripts/preprocess.py +93 -0
- scripts/preprocessor_gui.py +729 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer.py +1144 -0
- tme/backends/__init__.py +134 -0
- tme/backends/cupy_backend.py +309 -0
- tme/backends/matching_backend.py +1154 -0
- tme/backends/npfftw_backend.py +763 -0
- tme/backends/pytorch_backend.py +526 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2314 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/helpers.py +881 -0
- tme/matching_data.py +377 -0
- tme/matching_exhaustive.py +1553 -0
- tme/matching_memory.py +382 -0
- tme/matching_optimization.py +1123 -0
- tme/matching_utils.py +1180 -0
- tme/parser.py +429 -0
- tme/preprocessor.py +1291 -0
- tme/scoring.py +866 -0
- tme/structure.py +1428 -0
- tme/types.py +10 -0
@@ -0,0 +1,788 @@
#!python3
""" CLI interface for basic pyTME template matching functions.

    Copyright (c) 2023 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
import os
import argparse
import warnings
import importlib.util
from sys import exit
from time import time
from copy import deepcopy
from os.path import abspath

import numpy as np

from tme import Density, Preprocessor, __version__
from tme.matching_utils import (
    get_rotation_matrices,
    compute_parallelization_schedule,
    euler_from_rotationmatrix,
    scramble_phases,
    generate_tempfile_name,
    write_pickle,
)
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
from tme.matching_data import MatchingData
from tme.packaging import KernelFitting, GaussianKernel
from tme.analyzer import (
    MaxScoreOverRotations,
    PeakCallerMaximumFilter,
)
from tme.backends import backend

# from tme.matching_exhaustive import register_matching_exhaustive
# from tme.matching_memory import CCMemoryUsage
from tme.scoring import corr3_setup, corr3_scoring

# register_matching_exhaustive(
#     matching = "CC3",
#     matching_setup = corr3_setup,
#     matching_scoring = corr3_scoring,
#     memory_class = CCMemoryUsage
# )


def get_func_fullname(func) -> str:
    """Returns the full name of the given function, including its module."""
    return f"<function '{func.__module__}.{func.__name__}'>"


def print_block(name: str, data: dict, label_width=20) -> None:
    """Prints a formatted block of information."""
    print(f"\n> {name}")
    for key, value in data.items():
        formatted_value = str(value)
        print(f"  - {key + ':':<{label_width}} {formatted_value}")


def print_entry() -> None:
    width = 80
    text = f" pyTME v{__version__} "
    padding_total = width - len(text) - 2
    padding_left = padding_total // 2
    padding_right = padding_total - padding_left

    print("*" * width)
    print(f"*{ ' ' * padding_left }{text}{ ' ' * padding_right }*")
    print("*" * width)


def check_positive(value):
    ivalue = float(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
    return ivalue

def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
    """
    Loads a mask in CCP4/MRC format and assesses whether its sampling_rate
    and shape match the target.

    Parameters
    ----------
    mask_target : Density
        Object the mask should be applied to.
    mask_path : str
        Path to the mask in CCP4/MRC format.
    kwargs : dict, optional
        Keyword arguments passed to :py:meth:`tme.density.Density.from_file`.

    Raises
    ------
    ValueError
        If shape or sampling rate do not match between mask_target and mask.

    Returns
    -------
    Density
        A density instance if the mask was validated and loaded, otherwise None.
    """
    mask = mask_path
    if mask is not None:
        mask = Density.from_file(mask, **kwargs)
        mask.origin = deepcopy(mask_target.origin)
        if not np.allclose(mask.shape, mask_target.shape):
            raise ValueError(
                f"Expected shape of {mask_path} was {mask_target.shape},"
                f" got {mask.shape}"
            )
        if not np.allclose(mask.sampling_rate, mask_target.sampling_rate):
            raise ValueError(
                f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
                f", got {mask.sampling_rate}"
            )
    return mask


def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
    """
    Crop the provided data and mask to a smaller box based on a cutoff value.

    Parameters
    ----------
    data : Density
        The data that should be cropped.
    cutoff : float
        The threshold value determining which parts of the data are kept.
    data_mask : Density, optional
        A mask for the data that should be cropped.

    Returns
    -------
    bool
        True if the data was adjusted (cropped), otherwise False.

    Notes
    -----
    Cropping is performed in place.
    """
    if cutoff is None:
        return False

    box = data.trim_box(cutoff=cutoff)
    box_mask = box
    if data_mask is not None:
        box_mask = data_mask.trim_box(cutoff=cutoff)
    box = tuple(
        slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
        for arr, mask in zip(box, box_mask)
    )
    if box == tuple(slice(0, x) for x in data.shape):
        return False

    data.adjust_box(box)

    if data_mask:
        data_mask.adjust_box(box)

    return True

def parse_args():
    parser = argparse.ArgumentParser(description="Perform template matching.")
    parser.add_argument(
        "-m",
        "--target",
        dest="target",
        type=str,
        required=True,
        help="Path to a target in CCP4/MRC format.",
    )
    parser.add_argument(
        "--target_mask",
        dest="target_mask",
        type=str,
        required=False,
        help="Path to a mask for the target in CCP4/MRC format.",
    )
    parser.add_argument(
        "--cutoff_target",
        dest="cutoff_target",
        type=float,
        required=False,
        help="Target contour level (used for cropping).",
        default=None,
    )
    parser.add_argument(
        "--cutoff_template",
        dest="cutoff_template",
        type=float,
        required=False,
        help="Template contour level (used for cropping).",
        default=None,
    )
    parser.add_argument(
        "-i",
        "--template",
        dest="template",
        type=str,
        required=True,
        help="Path to a template in PDB/MMCIF or CCP4/MRC format.",
    )
    parser.add_argument(
        "--template_mask",
        dest="template_mask",
        type=str,
        required=False,
        help="Path to a mask for the template in CCP4/MRC format.",
    )
    parser.add_argument(
        "-o",
        dest="output",
        type=str,
        required=False,
        default="output.pickle",
        help="Path to output pickle file.",
    )
    parser.add_argument(
        "-s",
        dest="score",
        type=str,
        default="CC",
        help="Template matching scoring function.",
        choices=MATCHING_EXHAUSTIVE_REGISTER.keys(),
    )
    parser.add_argument(
        "-n",
        dest="cores",
        required=False,
        type=int,
        default=4,
        help="Number of cores used for template matching.",
    )
    parser.add_argument(
        "-r",
        "--ram",
        dest="ram",
        required=False,
        type=int,
        default=None,
        help="Amount of RAM that can be used in bytes.",
    )
    parser.add_argument(
        "-a",
        dest="angular_sampling",
        type=check_positive,
        default=40.0,
        help="Angular sampling rate for template matching. "
        "A lower number yields more rotations.",
    )
    parser.add_argument(
        "-p",
        dest="peak_calling",
        action="store_true",
        default=False,
        help="When set, perform peak calling instead of score aggregation.",
    )
    parser.add_argument(
        "--use_gpu",
        dest="use_gpu",
        action="store_true",
        default=False,
        help="Whether to perform computations on the GPU.",
    )
    parser.add_argument(
        "--gpu_indices",
        dest="gpu_indices",
        type=str,
        default=None,
        help="Comma-separated list of GPU indices to use. For example,"
        " 0,1 for the first and second GPU. Only used if --use_gpu is set."
        " If not provided but --use_gpu is set, CUDA_VISIBLE_DEVICES will"
        " be respected.",
    )
    parser.add_argument(
        "--invert_target_contrast",
        dest="invert_target_contrast",
        action="store_true",
        default=False,
        help="Invert the target contrast via multiplication with negative one and"
        " linear rescaling between zero and one. Note that this might lead to"
        " different baseline scores of individual target splits when using"
        " unnormalized scores. This option is intended for targets where the"
        " object to be matched has negative values, e.g. tomograms.",
    )
    parser.add_argument(
        "--no_edge_padding",
        dest="no_edge_padding",
        action="store_true",
        default=False,
        help="Do not pad the edges of the target. This is useful if the target"
        " has a well-defined bounding box, e.g. a density map.",
    )
    parser.add_argument(
        "--no_fourier_padding",
        dest="no_fourier_padding",
        action="store_true",
        default=False,
        help="Do not zero-pad input arrays to the full convolution shape,"
        " which is otherwise done for numerical stability.",
    )
    parser.add_argument(
        "--scramble_phases",
        dest="scramble_phases",
        action="store_true",
        default=False,
        help="Whether to phase-scramble the template for subsequent normalization.",
    )
    parser.add_argument(
        "--interpolation_order",
        dest="interpolation_order",
        required=False,
        type=int,
        default=3,
        help="Spline interpolation order used during rotations. If less than zero,"
        " no interpolation is performed.",
    )
    parser.add_argument(
        "--use_mixed_precision",
        dest="use_mixed_precision",
        action="store_true",
        default=False,
        help="Use float16 for real-valued operations where possible.",
    )
    parser.add_argument(
        "--use_memmap",
        dest="use_memmap",
        action="store_true",
        default=False,
        help="Use memmaps to offload large data objects to disk. This is"
        " particularly useful for large inputs when using --use_gpu.",
    )
    parser.add_argument(
        "--temp_directory",
        dest="temp_directory",
        default=None,
        help="Directory for temporary objects. Faster I/O typically improves runtime.",
    )
    parser.add_argument(
        "--gaussian_sigma",
        dest="gaussian_sigma",
        type=float,
        required=False,
        help="Sigma parameter for Gaussian filtering the template.",
    )

    parser.add_argument(
        "--bandpass_band",
        dest="bandpass_band",
        type=str,
        required=False,
        help="Comma-separated start and stop frequency for bandpass filtering the"
        " template, e.g. 0.1,0.5.",
    )
    parser.add_argument(
        "--bandpass_smooth",
        dest="bandpass_smooth",
        type=float,
        required=False,
        default=None,
        help="Smoothing parameter for the bandpass filter.",
    )

    parser.add_argument(
        "--tilt_range",
        dest="tilt_range",
        type=str,
        required=False,
        help="Comma-separated start and stop stage tilt angle, e.g. '50,45'. Used"
        " to create a wedge mask to be applied to the template.",
    )
    parser.add_argument(
        "--tilt_step",
        dest="tilt_step",
        type=float,
        required=False,
        default=None,
        help="Step size between tilts, e.g. '5'. When set, a more accurate"
        " wedge mask will be computed.",
    )
    parser.add_argument(
        "--wedge_smooth",
        dest="wedge_smooth",
        type=float,
        required=False,
        default=None,
        help="Gaussian sigma used to smooth the wedge mask.",
    )

    args = parser.parse_args()

    if args.interpolation_order < 0:
        args.interpolation_order = None

    if args.temp_directory is None:
        default = abspath(".")
        if os.environ.get("TMPDIR", None) is not None:
            default = os.environ.get("TMPDIR")
        args.temp_directory = default

    os.environ["TMPDIR"] = args.temp_directory

    args.pad_target_edges = not args.no_edge_padding
    args.pad_fourier = not args.no_fourier_padding

    if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
        raise ValueError(
            f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
        )

    gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if args.gpu_indices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices

    if args.use_gpu:
        gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if gpu_devices is None:
            # raise ValueError(
            #     "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set."
            # )
            print(
                "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set.",
                "Assuming device 0.",
            )
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        args.gpu_indices = [
            int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        ]

    return args

def main():
    args = parse_args()
    print_entry()

    target = Density.from_file(args.target, use_memmap=True)

    try:
        template = Density.from_file(args.template)
    except Exception:
        template = Density.from_structure(
            filename_or_structure=args.template,
            sampling_rate=target.sampling_rate,
        )

    if not np.allclose(target.sampling_rate, template.sampling_rate):
        print(
            f"Resampling template to {target.sampling_rate}. "
            "Consider providing a template with the same sampling rate as the target."
        )
        template = template.resample(target.sampling_rate, order=3)

    template_mask = load_and_validate_mask(
        mask_target=template, mask_path=args.template_mask
    )
    target_mask = load_and_validate_mask(
        mask_target=target, mask_path=args.target_mask, use_memmap=True
    )

    initial_shape = target.shape
    is_cropped = crop_data(
        data=target, data_mask=target_mask, cutoff=args.cutoff_target
    )
    print_block(
        name="Target",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(target.sampling_rate, 2)),
            "Final Shape": target.shape,
        },
    )
    if is_cropped:
        args.target = generate_tempfile_name(suffix=".mrc")
        target.to_file(args.target)

        if target_mask:
            args.target_mask = generate_tempfile_name(suffix=".mrc")
            target_mask.to_file(args.target_mask)
            print_block(
                name="Target Mask",
                data={
                    "Initial Shape": initial_shape,
                    "Sampling Rate": tuple(np.round(target_mask.sampling_rate, 2)),
                    "Final Shape": target_mask.shape,
                },
            )

    initial_shape = template.shape
    _ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)
    template, translation = template.centered(0)

    print_block(
        name="Template",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(template.sampling_rate, 2)),
            "Final Shape": template.shape,
        },
    )

    template_filter = {}
    if args.gaussian_sigma is not None:
        template.data = Preprocessor().gaussian_filter(
            sigma=args.gaussian_sigma, template=template.data
        )

    if args.bandpass_band is not None:
        bandpass_start, bandpass_stop = [
            float(x) for x in args.bandpass_band.split(",")
        ]
        if args.bandpass_smooth is None:
            args.bandpass_smooth = 0

        template_filter["bandpass_mask"] = {
            "minimum_frequency": bandpass_start,
            "maximum_frequency": bandpass_stop,
            "gaussian_sigma": args.bandpass_smooth,
        }

    if args.tilt_range is not None:
        args.wedge_smooth = args.wedge_smooth if args.wedge_smooth is not None else 0
        tilt_start, tilt_stop = [float(x) for x in args.tilt_range.split(",")]

        if args.tilt_step is not None:
            tilt_angles = np.arange(
                -tilt_start, tilt_stop + args.tilt_step, args.tilt_step
            )
            angles = np.zeros((template.data.ndim, tilt_angles.size))
            angles[2, :] = tilt_angles
            template_filter["wedge_mask"] = {
                "tilt_angles": angles,
                "sigma": args.wedge_smooth,
            }
        else:
            template_filter["continuous_wedge_mask"] = {
                "start_tilt": tilt_start,
                "stop_tilt": tilt_stop,
                "tilt_axis": 1,
                "infinite_plane": True,
                "sigma": args.wedge_smooth,
            }

    if template_mask is None:
        enclosing_box = template.minimum_enclosing_box(0, use_geometric_center=False)
        template_mask = template.empty
        template_mask.adjust_box(enclosing_box)
        template_mask.data[:] = 1
        translation = np.zeros_like(translation)

    template_mask.pad(template.shape, center=False)
    origin_translation = np.divide(
        np.subtract(template.origin, template_mask.origin), template.sampling_rate
    )
    translation = np.add(translation, origin_translation)

    template_mask = template_mask.rigid_transform(
        rotation_matrix=np.eye(template_mask.data.ndim),
        translation=-translation,
    )

    print_block(
        name="Template Mask",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(template_mask.sampling_rate, 2)),
            "Final Shape": template_mask.shape,
        },
    )
    print("\n" + "-" * 80)

    if args.scramble_phases:
        template.data = scramble_phases(template.data, noise_proportion=1.0)

    available_memory, ram_scaling = backend.get_available_memory(), 1.0
    if args.use_gpu:
        args.cores, ram_scaling = len(args.gpu_indices), 0.85
        has_torch = importlib.util.find_spec("torch") is not None
        has_cupy = importlib.util.find_spec("cupy") is not None

        if not has_torch and not has_cupy:
            raise ValueError(
                "Found neither CuPy nor PyTorch installation. You need to install"
                " either to enable GPU support."
            )

        if args.peak_calling:
            preferred_backend = "pytorch"
            if not has_torch:
                preferred_backend = "cupy"
            backend.change_backend(backend_name=preferred_backend, device="cuda")
        else:
            preferred_backend = "cupy"
            if not has_cupy:
                preferred_backend = "pytorch"
            backend.change_backend(backend_name=preferred_backend, device="cuda")
        if args.use_mixed_precision and preferred_backend == "pytorch":
            raise NotImplementedError(
                "pytorch backend does not yet support mixed precision."
                " Consider installing CuPy to enable this feature."
            )
        elif args.use_mixed_precision:
            backend.change_backend(
                backend_name="cupy",
                default_dtype=backend._array_backend.float16,
                complex_dtype=backend._array_backend.complex64,
                default_dtype_int=backend._array_backend.int16,
            )
        available_memory = backend.get_available_memory() * args.cores

    if args.ram is None:
        args.ram = int(ram_scaling * available_memory)

    target_padding = np.zeros_like(template.shape)
    if args.pad_target_edges:
        target_padding = template.shape

    template_box = template.shape
    if not args.pad_fourier:
        template_box = np.ones(len(template_box), dtype=int)

    callback_class = MaxScoreOverRotations
    if args.peak_calling:
        callback_class, kernel_class = KernelFitting, GaussianKernel

    splits, schedule = compute_parallelization_schedule(
        shape1=target.shape,
        shape2=template_box,
        shape1_padding=target_padding,
        max_cores=args.cores,
        max_ram=args.ram,
        split_only_outer=args.use_gpu,
        matching_method=args.score,
        analyzer_method=callback_class.__name__,
        backend=backend._backend_name,
        float_nbytes=backend.datatype_bytes(backend._default_dtype),
        complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
        integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
    )

    if splits is None:
        print(
            "Found no suitable parallelization schedule. Consider increasing"
            " available RAM or decreasing number of cores."
        )
        exit(-1)

    analyzer_args = {
        "score_threshold": 0.2,
        "number_of_peaks": 1000,
        "convolution_mode": "valid",
        "use_memmap": args.use_memmap,
    }

    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
    if args.peak_calling:
        autocorrelation_args = {
            "convolution_mode": "valid",
        }

        ndim = template.data.ndim
        matching_data = MatchingData(
            target=template.data.copy(), template=template.data.copy()
        )
        if template_mask is not None:
            matching_data.template_mask = template_mask.data
        matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)
        autocorrelation = scan_subsets(
            matching_data=matching_data,
            job_schedule=(1, 1),
            matching_score=matching_score,
            matching_setup=matching_setup,
            callback_class=MaxScoreOverRotations,
            callback_class_args=autocorrelation_args,
            target_splits={k: 1 for k in range(ndim)},
            pad_target_edges=False,
            interpolation_order=args.interpolation_order,
        )
        autocorrelation = autocorrelation[0]
        autocorrelation /= autocorrelation.max()
        np.maximum(autocorrelation, 0, out=autocorrelation)
        autocorrelation = Density(
            autocorrelation,
            origin=tuple(0 for _ in range(autocorrelation.ndim)),
            sampling_rate=1,
        )
        autocorrelation = autocorrelation.data
        analyzer_args["kernel_class"] = GaussianKernel
        analyzer_args["kernel_box"] = autocorrelation.shape
        kernel_params = kernel_class.estimate_parameters_batch(
            backend.to_backend_array(autocorrelation[None, ...])
        )
        analyzer_args["kernel_params"] = tuple(param[0] for param in kernel_params)
        analyzer_args["peak_caller"] = PeakCallerMaximumFilter
        analyzer_args["min_distance"] = np.max(autocorrelation.shape)

    matching_data = MatchingData(target=target, template=template.data)
    matching_data.rotations = get_rotation_matrices(args.angular_sampling)

    matching_data.template_filter = template_filter
    matching_data._invert_target = args.invert_target_contrast
    if target_mask is not None:
        matching_data.target_mask = target_mask
    if template_mask is not None:
        matching_data.template_mask = template_mask.data

    n_splits = np.prod(list(splits.values()))
    target_split = ", ".join(
        [":".join([str(x) for x in axis]) for axis in splits.items()]
    )
    gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
    options = {
        "CPU Cores": args.cores,
        "Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
        "Use Mixed Precision": args.use_mixed_precision,
        "Assigned Memory [MB]": f"{args.ram // 1e6} [out of {available_memory // 1e6}]",
        "Temporary Directory": args.temp_directory,
        "Extend Fourier Grid": not args.no_fourier_padding,
        "Extend Target Edges": not args.no_edge_padding,
        "Interpolation Order": args.interpolation_order,
        "Score": f"{args.score}",
        "Setup Function": f"{get_func_fullname(matching_setup)}",
        "Scoring Function": f"{get_func_fullname(matching_score)}",
        "Angular Sampling": f"{args.angular_sampling}"
        f" [{matching_data.rotations.shape[0]} rotations]",
        "Scramble Template": args.scramble_phases,
        "Target Splits": f"{target_split} [N={n_splits}]",
    }

    print_block(
        name="Template Matching Options",
        data=options,
        label_width=max(len(key) for key in options.keys()) + 2,
    )

    options = {"Analyzer": callback_class, **analyzer_args}
    print_block(
        name="Score Analysis Options",
        data=options,
        label_width=max(len(key) for key in options.keys()) + 2,
    )
    print("\n" + "-" * 80)

    outer_jobs = f"{schedule[0]} job{'s' if schedule[0] > 1 else ''}"
    inner_jobs = f"{schedule[1]} core{'s' if schedule[1] > 1 else ''}"
    n_splits = f"{n_splits} split{'s' if n_splits > 1 else ''}"
    print(f"\nDistributing {n_splits} on {outer_jobs} each using {inner_jobs}.")

    start = time()
    print("Running Template Matching. This might take a while ...")
    candidates = scan_subsets(
        matching_data=matching_data,
        job_schedule=schedule,
        matching_score=matching_score,
        matching_setup=matching_setup,
        callback_class=callback_class,
        callback_class_args=analyzer_args,
        target_splits=splits,
        pad_target_edges=args.pad_target_edges,
        pad_fourier=args.pad_fourier,
        interpolation_order=args.interpolation_order,
    )

    candidates = list(candidates) if candidates is not None else []
    if callback_class == MaxScoreOverRotations:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            candidates[3] = {
                x: euler_from_rotationmatrix(
                    np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
                        candidates[0].ndim, candidates[0].ndim
                    )
                )
                for i, x in candidates[3].items()
            }

    candidates.append((target.origin, template.origin, target.sampling_rate, args))
    write_pickle(data=candidates, filename=args.output)

    runtime = time() - start
    print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")


if __name__ == "__main__":
    main()
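For orientation, the core flow of main() can also be driven directly from Python. The sketch below is a minimal example built only from the tme 0.1.5 calls used in this script; the file paths, the fixed (1, 4) job schedule, the single-split setting, and the reduced analyzer arguments are illustrative assumptions, and the cropping, filtering, mask handling, and parallelization-schedule steps performed by the script are omitted.

# Minimal sketch of a programmatic template-matching run.
# Paths, schedule, and analyzer arguments are hypothetical simplifications.
from tme import Density
from tme.analyzer import MaxScoreOverRotations
from tme.matching_data import MatchingData
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
from tme.matching_utils import get_rotation_matrices, write_pickle

target = Density.from_file("target.mrc", use_memmap=True)   # hypothetical path
template = Density.from_file("template.mrc")                 # hypothetical path

# Pair target and template and attach a coarse rotation set (40 degree sampling).
matching_data = MatchingData(target=target, template=template.data)
matching_data.rotations = get_rotation_matrices(40.0)

# Look up the setup and scoring callables for the cross-correlation score.
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER["CC"]

candidates = scan_subsets(
    matching_data=matching_data,
    job_schedule=(1, 4),                      # one outer job using four cores
    matching_setup=matching_setup,
    matching_score=matching_score,
    callback_class=MaxScoreOverRotations,
    callback_class_args={"score_threshold": 0.2, "convolution_mode": "valid"},
    target_splits={k: 1 for k in range(target.data.ndim)},  # no target splitting
    pad_target_edges=True,
    pad_fourier=True,
    interpolation_order=3,
)
candidates = list(candidates) if candidates is not None else []
write_pickle(data=candidates, filename="output.pickle")      # hypothetical path

The output pickle can then be handed to postprocess.py, mirroring what the CLI writes at the end of main().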