pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/match_template.py +148 -126
  2. pytme-0.2.0b0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0b0.dist-info/RECORD +66 -0
  6. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +148 -126
  9. scripts/match_template_filters.py +852 -0
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +545 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +33 -2
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +156 -63
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +74 -33
  23. tme/matching_exhaustive.py +351 -208
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +728 -651
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessor.py +21 -18
  29. tme/structure.py +2 -37
  30. pytme-0.1.9.data/scripts/postprocess.py +0 -625
  31. pytme-0.1.9.dist-info/RECORD +0 -61
  32. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/estimate_ram_usage.py +0 -0
  33. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocess.py +0 -0
  34. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/LICENSE +0 -0
  35. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/entry_points.txt +0 -0
  36. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,852 @@
1
+ #!python3
2
+ """ CLI interface for basic pyTME template matching functions.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import os
9
+ import argparse
10
+ import warnings
11
+ import importlib.util
12
+ from sys import exit
13
+ from time import time
14
+ from typing import Tuple
15
+ from copy import deepcopy
16
+ from os.path import abspath
17
+
18
+ import numpy as np
19
+
20
+ from tme import Density, __version__
21
+ from tme.matching_utils import (
22
+ get_rotation_matrices,
23
+ compute_parallelization_schedule,
24
+ euler_from_rotationmatrix,
25
+ scramble_phases,
26
+ generate_tempfile_name,
27
+ write_pickle,
28
+ )
29
+ from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
30
+ from tme.matching_data import MatchingData
31
+ from tme.analyzer import (
32
+ MaxScoreOverRotations,
33
+ PeakCallerMaximumFilter,
34
+ )
35
+ from tme.preprocessing import Compose
36
+ from tme.backends import backend
37
+
38
+
39
def get_func_fullname(func) -> str:
    """Describe *func* as ``<function 'module.name'>`` for log output."""
    module_name = func.__module__
    func_name = func.__name__
    qualified = module_name + "." + func_name
    return "<function '" + qualified + "'>"
42
+
43
+
44
def print_block(name: str, data: dict, label_width=20) -> None:
    """Print a titled block of aligned ``key: value`` lines.

    Parameters
    ----------
    name : str
        Heading, printed after an empty line and prefixed with ``> ``.
    data : dict
        Entries to print, one line each, in the dict's iteration order.
    label_width : int, optional
        Minimum width of the ``key:`` column. Defaults to 20.
    """
    print(f"\n> {name}")
    for label, entry in data.items():
        column = (label + ":").ljust(label_width)
        print(" - " + column + " " + str(entry))
50
+
51
+
52
def print_entry() -> None:
    """Print an 80 column banner of asterisks framing the pyTME version."""
    frame_width = 80
    title = f" pyTME v{__version__} "
    border = "*" * frame_width

    # Two columns are taken by the framing asterisks; any odd leftover
    # space goes to the right-hand side.
    spare = frame_width - len(title) - 2
    left = spare // 2
    right = spare - left

    print(border, f"*{' ' * left}{title}{' ' * right}*", border, sep="\n")
62
+
63
+
64
def check_positive(value):
    """Argparse ``type=`` converter accepting only strictly positive floats.

    Parameters
    ----------
    value : str or float
        Raw command line value.

    Returns
    -------
    float
        ``value`` converted to float.

    Raises
    ------
    ValueError
        If ``value`` cannot be converted to float (raised by ``float``).
    argparse.ArgumentTypeError
        If the converted value is not strictly positive, including NaN.
    """
    ivalue = float(value)
    # ``not ivalue > 0`` also rejects NaN, for which ``nan <= 0`` is False
    # and which would otherwise be accepted as a "positive" value.
    if not ivalue > 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
    return ivalue
69
+
70
+
71
def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
    """
    Load a mask in CCP4/MRC format and assess whether its sampling_rate
    and shape match its target.

    Parameters
    ----------
    mask_target : Density
        Object the mask should be applied to.
    mask_path : str
        Path to the mask in CCP4/MRC format. If None, no mask is loaded.
    kwargs : dict, optional
        Keyword arguments passed to :py:meth:`tme.density.Density.from_file`.

    Raises
    ------
    ValueError
        If shape or sampling rate do not match between mask_target and mask.

    Returns
    -------
    Density or None
        The loaded mask with its origin set to a copy of the target's origin,
        or None if ``mask_path`` is None.
    """
    if mask_path is None:
        return None

    mask = Density.from_file(mask_path, **kwargs)
    mask.origin = deepcopy(mask_target.origin)

    # Compare shapes exactly. np.allclose broadcasts, so e.g. a (5,) mask
    # would pass against a (5, 5) target and other rank mismatches would
    # surface as an opaque broadcasting error.
    if tuple(mask.shape) != tuple(mask_target.shape):
        raise ValueError(
            f"Expected shape of {mask_path} was {mask_target.shape},"
            f" got {mask.shape}"
        )
    if not np.allclose(mask.sampling_rate, mask_target.sampling_rate):
        raise ValueError(
            f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
            f", got {mask.sampling_rate}"
        )
    return mask
109
+
110
+
111
def crop_data(data: "Density", cutoff: float, data_mask: "Density" = None) -> bool:
    """
    Crop the provided data and mask to a smaller box based on a cutoff value.

    Parameters
    ----------
    data : Density
        The data that should be cropped.
    cutoff : float
        The threshold value to determine which parts of the data should be kept.
        If None, nothing is cropped.
    data_mask : Density, optional
        A mask for the data that should be cropped with the same box.

    Returns
    -------
    bool
        Returns True if the data was adjusted (cropped), otherwise returns False.

    Notes
    -----
    Cropping is performed in place. If a mask is given, the union of the data
    and mask boxes is used so that neither loses content.
    """
    if cutoff is None:
        return False

    # Normalize to a tuple so the full-box comparison below is well defined.
    box = tuple(data.trim_box(cutoff=cutoff))
    if data_mask is not None:
        mask_box = data_mask.trim_box(cutoff=cutoff)
        box = tuple(
            slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
            for arr, mask in zip(box, mask_box)
        )

    # The box already spans the whole array: nothing to crop.
    if box == tuple(slice(0, x) for x in data.shape):
        return False

    data.adjust_box(box)
    if data_mask is not None:
        data_mask.adjust_box(box)

    return True
153
+
154
+
155
def parse_args():
    """Parse and post-process the command line arguments.

    Besides parsing, this function
    - maps a negative ``--interpolation_order`` to None,
    - resolves ``--temp_directory`` (``$TMPDIR`` first, then the current
      directory) and exports it as ``TMPDIR``,
    - derives ``pad_target_edges`` and ``pad_fourier`` from the ``--no_*`` flags,
    - with ``--use_gpu``, exports ``--gpu_indices`` as ``CUDA_VISIBLE_DEVICES``
      (defaulting to device 0) and converts ``gpu_indices`` to a list of ints,
    - parses ``--wedge_axes`` into a list of ints.

    Returns
    -------
    argparse.Namespace
        The parsed arguments. ``gpu_indices`` is a list of ints when
        ``--use_gpu`` is set and None otherwise.

    Raises
    ------
    ValueError
        On an unknown score or inconsistent wedge / tilt / CTF options.
    """
    parser = argparse.ArgumentParser(description="Perform template matching.")

    io_group = parser.add_argument_group("Input / Output")
    io_group.add_argument(
        "-m",
        "--target",
        dest="target",
        type=str,
        required=True,
        help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
        "tme.density.Density.from_file "
        "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
    )
    io_group.add_argument(
        "--target_mask",
        dest="target_mask",
        type=str,
        required=False,
        help="Path to a mask for the target in a supported format (see target).",
    )
    io_group.add_argument(
        "-i",
        "--template",
        dest="template",
        type=str,
        required=True,
        help="Path to a template in PDB/MMCIF or other supported formats (see target).",
    )
    io_group.add_argument(
        "--template_mask",
        dest="template_mask",
        type=str,
        required=False,
        help="Path to a mask for the template in a supported format (see target).",
    )
    io_group.add_argument(
        "-o",
        "--output",
        dest="output",
        type=str,
        required=False,
        default="output.pickle",
        help="Path to the output pickle file.",
    )
    io_group.add_argument(
        "--invert_target_contrast",
        dest="invert_target_contrast",
        action="store_true",
        default=False,
        help="Invert the target's contrast and rescale linearly between zero and one. "
        "This option is intended for targets where templates to-be-matched have "
        "negative values, e.g. tomograms.",
    )
    io_group.add_argument(
        "--scramble_phases",
        dest="scramble_phases",
        action="store_true",
        default=False,
        help="Phase scramble the template to generate a noise score background.",
    )

    scoring_group = parser.add_argument_group("Scoring")
    scoring_group.add_argument(
        "-s",
        dest="score",
        type=str,
        default="FLCSphericalMask",
        choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
        help="Template matching scoring function.",
    )
    scoring_group.add_argument(
        "-a",
        dest="angular_sampling",
        type=check_positive,
        default=40.0,
        help="Angular sampling rate for template matching. "
        "A lower number yields more rotations. Values >= 180 sample only the identity.",
    )

    computation_group = parser.add_argument_group("Computation")
    computation_group.add_argument(
        "-n",
        dest="cores",
        required=False,
        type=int,
        default=4,
        help="Number of cores used for template matching.",
    )
    computation_group.add_argument(
        "--use_gpu",
        dest="use_gpu",
        action="store_true",
        default=False,
        help="Whether to perform computations on the GPU.",
    )
    computation_group.add_argument(
        "--gpu_indices",
        dest="gpu_indices",
        type=str,
        default=None,
        help="Comma-separated list of GPU indices to use. For example,"
        " 0,1 for the first and second GPU. Only used if --use_gpu is set."
        " If not provided but --use_gpu is set, CUDA_VISIBLE_DEVICES will"
        " be respected.",
    )
    computation_group.add_argument(
        "-r",
        "--ram",
        dest="memory",
        required=False,
        type=int,
        default=None,
        help="Amount of memory that can be used in bytes.",
    )
    computation_group.add_argument(
        "--memory_scaling",
        dest="memory_scaling",
        required=False,
        type=float,
        default=0.85,
        help="Fraction of available memory that can be used. Defaults to 0.85 and is "
        "ignored if --ram is set",
    )
    computation_group.add_argument(
        "--use_mixed_precision",
        dest="use_mixed_precision",
        action="store_true",
        default=False,
        help="Use float16 for real values operations where possible.",
    )
    computation_group.add_argument(
        "--use_memmap",
        dest="use_memmap",
        action="store_true",
        default=False,
        help="Use memmaps to offload large data objects to disk. "
        "Particularly useful for large inputs in combination with --use_gpu.",
    )
    computation_group.add_argument(
        "--temp_directory",
        dest="temp_directory",
        default=None,
        help="Directory for temporary objects. Faster I/O improves runtime.",
    )

    filter_group = parser.add_argument_group("Filters")
    filter_group.add_argument(
        "--lowpass",
        dest="lowpass",
        type=float,
        required=False,
        help="Resolution to lowpass filter template and target to in the same unit "
        "as the sampling rate of template and target (typically Ångstrom).",
    )
    filter_group.add_argument(
        "--highpass",
        dest="highpass",
        type=float,
        required=False,
        help="Resolution to highpass filter template and target to in the same unit "
        "as the sampling rate of template and target (typically Ångstrom).",
    )
    filter_group.add_argument(
        "--whiten_spectrum",
        dest="whiten_spectrum",
        action="store_true",
        default=False,
        help="Apply spectral whitening to template and target based on target spectrum.",
    )
    filter_group.add_argument(
        "--wedge_axes",
        dest="wedge_axes",
        type=str,
        required=False,
        default="0,2",
        help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
        "in z-direction and tilted over the x axis.",
    )
    filter_group.add_argument(
        "--tilt_angles",
        dest="tilt_angles",
        type=str,
        required=False,
        default=None,
        help="Path to a file with angles and corresponding doses, or comma separated "
        "start and stop stage tilt angle, e.g. 50,45, which yields a continuous wedge "
        "mask. Alternatively, a tilt step size can be specified like 50,45:5.0 to "
        "sample 5.0 degree tilt angle steps.",
    )
    filter_group.add_argument(
        "--tilt_weighting",
        dest="tilt_weighting",
        type=str,
        required=False,
        choices=["angle", "relion", "grigorieff"],
        default=None,
        help="Weighting scheme used to reweight individual tilts. Available options: "
        "angle (cosine based weighting), "
        "relion (relion formalism for wedge weighting), "
        "grigorieff (exposure filter as defined in Grant and Grigorieff 2015).",
    )
    filter_group.add_argument(
        "--ctf_file",
        dest="ctf_file",
        type=str,
        required=False,
        default=None,
        help="Path to a file with CTF parameters.",
    )

    performance_group = parser.add_argument_group("Performance")
    performance_group.add_argument(
        "--cutoff_target",
        dest="cutoff_target",
        type=float,
        required=False,
        default=None,
        help="Target contour level (used for cropping).",
    )
    performance_group.add_argument(
        "--cutoff_template",
        dest="cutoff_template",
        type=float,
        required=False,
        default=None,
        help="Template contour level (used for cropping).",
    )
    performance_group.add_argument(
        "--no_centering",
        dest="no_centering",
        action="store_true",
        help="Assumes the template is already centered and omits centering.",
    )
    performance_group.add_argument(
        "--no_edge_padding",
        dest="no_edge_padding",
        action="store_true",
        default=False,
        help="Whether to not pad the edges of the target. Can be set if the target"
        " has a well defined bounding box, e.g. a masked reconstruction.",
    )
    performance_group.add_argument(
        "--no_fourier_padding",
        dest="no_fourier_padding",
        action="store_true",
        default=False,
        help="Whether input arrays should not be zero-padded to full convolution shape "
        "for numerical stability. When working with very large targets, e.g. tomograms, "
        "it is safe to use this flag and benefit from the performance gain.",
    )
    performance_group.add_argument(
        "--interpolation_order",
        dest="interpolation_order",
        required=False,
        type=int,
        default=3,
        help="Spline interpolation used for template rotations. If less than zero "
        "no interpolation is performed.",
    )

    analyzer_group = parser.add_argument_group("Analyzer")
    analyzer_group.add_argument(
        "--score_threshold",
        dest="score_threshold",
        required=False,
        type=float,
        default=0,
        help="Minimum template matching scores to consider for analysis.",
    )
    analyzer_group.add_argument(
        "-p",
        dest="peak_calling",
        action="store_true",
        default=False,
        help="Perform peak calling instead of score aggregation.",
    )
    args = parser.parse_args()

    if args.interpolation_order < 0:
        args.interpolation_order = None

    if args.temp_directory is None:
        default = abspath(".")
        if os.environ.get("TMPDIR", None) is not None:
            default = os.environ.get("TMPDIR")
        args.temp_directory = default

    os.environ["TMPDIR"] = args.temp_directory

    args.pad_target_edges = not args.no_edge_padding
    args.pad_fourier = not args.no_fourier_padding

    if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
        raise ValueError(
            f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
        )

    if args.use_gpu:
        if args.gpu_indices is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
        if os.environ.get("CUDA_VISIBLE_DEVICES", None) is None:
            print(
                "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set.",
                "Assuming device 0.",
            )
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        args.gpu_indices = [
            int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        ]
    else:
        # --gpu_indices is only honoured together with --use_gpu. Reset it so
        # downstream code never takes len() of the raw comma-separated string.
        args.gpu_indices = None

    if args.wedge_axes is not None:
        args.wedge_axes = [int(x) for x in args.wedge_axes.split(",")]

    if args.tilt_angles is not None and args.wedge_axes is None:
        raise ValueError("Wedge axes have to be specified with tilt angles.")

    if args.ctf_file is not None and args.wedge_axes is None:
        raise ValueError("Wedge axes have to be specified with CTF parameters.")
    if args.ctf_file is not None and args.tilt_angles is None:
        raise ValueError("Angles have to be specified with CTF parameters.")

    # setup_filter reads exactly wedge_axes[0] (opening) and wedge_axes[1] (tilt).
    uses_wedge = args.tilt_angles is not None or args.ctf_file is not None
    if uses_wedge and len(args.wedge_axes) != 2:
        raise ValueError(
            "Wedge axes have to be two comma-separated integers, e.g. 0,2."
        )

    return args
482
+
483
+
484
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
    """Build the template and target filter chains from the parsed arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments. Reads ``tilt_angles``, ``tilt_weighting``,
        ``wedge_axes``, ``ctf_file``, ``lowpass``, ``highpass`` and
        ``whiten_spectrum``.
    template : Density
        Template. Its ``sampling_rate`` is passed to the wedge and band-pass
        filters.
    target : Density
        Target. Currently unused; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(template_filter, target_filter)``. Each is a ``Compose`` of the
        selected filters, or None when no filter applies.
    """
    template_filter, target_filter = [], []

    if args.tilt_angles is not None:
        from tme.preprocessing.tilt_series import (
            Wedge,
            WedgeReconstructed,
            ReconstructFromTilt,
        )

        try:
            wedge = Wedge.from_file(args.tilt_angles)
            wedge.weight_type = args.tilt_weighting
        except FileNotFoundError:
            # Not a file: interpret as "start,stop" or "start,stop:step".
            tilt_step = None
            tilt_start, tilt_stop = args.tilt_angles.split(",")
            if ":" in tilt_stop:
                tilt_stop, tilt_step = tilt_stop.split(":")
            tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)

            tilt_angles = None
            if tilt_step is not None:
                tilt_step = float(tilt_step)
                tilt_angles = np.arange(
                    -tilt_start, tilt_stop + tilt_step, tilt_step
                ).tolist()
            wedge = WedgeReconstructed(
                angles=tilt_angles,
                start_tilt=tilt_start,
                stop_tilt=tilt_stop,
            )
        wedge.opening_axis = args.wedge_axes[0]
        wedge.tilt_axis = args.wedge_axes[1]
        wedge.sampling_rate = template.sampling_rate
        template_filter.append(wedge)
        if not isinstance(wedge, WedgeReconstructed):
            template_filter.append(ReconstructFromTilt())

    if args.ctf_file is not None:
        # Imported here as well so this branch does not depend on the tilt
        # branch having bound ReconstructFromTilt.
        from tme.preprocessing.tilt_series import CTF, ReconstructFromTilt

        ctf = CTF.from_file(args.ctf_file)
        ctf.tilt_axis = args.wedge_axes[1]
        ctf.opening_axis = args.wedge_axes[0]
        # Insert the CTF exactly once, ahead of ReconstructFromTilt
        # (presumably because it operates per tilt - confirm). The previous
        # version also appended it, so it was applied twice.
        if template_filter and isinstance(template_filter[-1], ReconstructFromTilt):
            template_filter.insert(-1, ctf)
        else:
            template_filter.insert(0, ctf)
            template_filter.insert(1, ReconstructFromTilt())

    if args.lowpass is not None or args.highpass is not None:
        from tme.preprocessing import BandPassFilter

        bandpass = BandPassFilter(
            use_gaussian=True,
            lowpass=args.lowpass,
            highpass=args.highpass,
            sampling_rate=template.sampling_rate,
        )
        template_filter.append(bandpass)
        target_filter.append(bandpass)

    if args.whiten_spectrum:
        from tme.preprocessing import LinearWhiteningFilter

        whitening_filter = LinearWhiteningFilter()
        template_filter.append(whitening_filter)
        target_filter.append(whitening_filter)

    template_filter = Compose(template_filter) if template_filter else None
    target_filter = Compose(target_filter) if target_filter else None

    return template_filter, target_filter
558
+
559
+
560
def main():
    """Command line entry point for template matching with optional filters.

    Loads target and template. If the template cannot be read with
    ``Density.from_file``, it is built with ``Density.from_structure``.
    The function then optionally crops and centers the inputs, builds the
    template mask, selects the compute backend and derives a parallelization
    schedule. It runs :py:func:`scan_subsets` and pickles the results to
    ``args.output``, together with ``(target.origin, template.origin,
    target.sampling_rate, args)``.
    """
    args = parse_args()
    print_entry()

    target = Density.from_file(args.target, use_memmap=True)

    try:
        template = Density.from_file(args.template)
    except Exception:
        # Not readable as a density: treat it as an atomic structure.
        template = Density.from_structure(
            filename_or_structure=args.template,
            sampling_rate=target.sampling_rate,
        )

    if not np.allclose(target.sampling_rate, template.sampling_rate):
        print(
            f"Resampling template to {target.sampling_rate}. "
            "Consider providing a template with the same sampling rate as the target."
        )
        template = template.resample(target.sampling_rate, order=3)

    template_mask = load_and_validate_mask(
        mask_target=template, mask_path=args.template_mask
    )
    target_mask = load_and_validate_mask(
        mask_target=target, mask_path=args.target_mask, use_memmap=True
    )

    initial_shape = target.shape
    is_cropped = crop_data(
        data=target, data_mask=target_mask, cutoff=args.cutoff_target
    )
    print_block(
        name="Target",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(target.sampling_rate, 2)),
            "Final Shape": target.shape,
        },
    )
    if is_cropped:
        # Persist the cropped arrays and point args at them. args is pickled
        # together with the results at the end.
        args.target = generate_tempfile_name(suffix=".mrc")
        target.to_file(args.target)

        if target_mask is not None:
            args.target_mask = generate_tempfile_name(suffix=".mrc")
            target_mask.to_file(args.target_mask)

    if target_mask is not None:
        print_block(
            name="Target Mask",
            data={
                "Initial Shape": initial_shape,
                "Sampling Rate": tuple(np.round(target_mask.sampling_rate, 2)),
                "Final Shape": target_mask.shape,
            },
        )

    initial_shape = template.shape
    _ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)

    translation = np.zeros(len(template.shape), dtype=np.float32)
    if not args.no_centering:
        template, translation = template.centered(0)
    print_block(
        name="Template",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(template.sampling_rate, 2)),
            "Final Shape": template.shape,
        },
    )

    if template_mask is None:
        # No mask given: use a mask of ones. Unless centering is disabled, it
        # is restricted to the template's minimum enclosing box.
        template_mask = template.empty
        if not args.no_centering:
            enclosing_box = template.minimum_enclosing_box(
                0, use_geometric_center=False
            )
            template_mask.adjust_box(enclosing_box)

        template_mask.data[:] = 1
        translation = np.zeros_like(translation)

    # Bring the mask onto the template's grid: pad to the template shape and
    # shift by the centering translation plus the origin offset (in voxels).
    template_mask.pad(template.shape, center=False)
    origin_translation = np.divide(
        np.subtract(template.origin, template_mask.origin), template.sampling_rate
    )
    translation = np.add(translation, origin_translation)

    template_mask = template_mask.rigid_transform(
        rotation_matrix=np.eye(template_mask.data.ndim),
        translation=-translation,
        order=1,
    )
    template_mask.origin = template.origin.copy()
    print_block(
        name="Template Mask",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(template_mask.sampling_rate, 2)),
            "Final Shape": template_mask.shape,
        },
    )
    print("\n" + "-" * 80)

    if args.scramble_phases:
        template.data = scramble_phases(
            template.data, noise_proportion=1.0, normalize_power=True
        )

    available_memory = backend.get_available_memory()
    if args.use_gpu:
        args.cores = len(args.gpu_indices)
        has_torch = importlib.util.find_spec("torch") is not None
        has_cupy = importlib.util.find_spec("cupy") is not None

        if not has_torch and not has_cupy:
            raise ValueError(
                "Found neither CuPy nor PyTorch installation. You need to install"
                " either to enable GPU support."
            )

        # Peak calling prefers pytorch; score aggregation prefers cupy.
        if args.peak_calling:
            preferred_backend = "pytorch"
            if not has_torch:
                preferred_backend = "cupy"
            backend.change_backend(backend_name=preferred_backend, device="cuda")
        else:
            preferred_backend = "cupy"
            if not has_cupy:
                preferred_backend = "pytorch"
            backend.change_backend(backend_name=preferred_backend, device="cuda")
            if args.use_mixed_precision and preferred_backend == "pytorch":
                raise NotImplementedError(
                    "pytorch backend does not yet support mixed precision."
                    " Consider installing CuPy to enable this feature."
                )
            elif args.use_mixed_precision:
                backend.change_backend(
                    backend_name="cupy",
                    default_dtype=backend._array_backend.float16,
                    complex_dtype=backend._array_backend.complex64,
                    default_dtype_int=backend._array_backend.int16,
                )
        available_memory = backend.get_available_memory() * args.cores
        if preferred_backend == "pytorch" and args.interpolation_order == 3:
            args.interpolation_order = 1

    if args.memory is None:
        args.memory = int(args.memory_scaling * available_memory)

    target_padding = np.zeros_like(template.shape)
    if args.pad_target_edges:
        target_padding = template.shape

    template_box = template.shape
    if not args.pad_fourier:
        template_box = np.ones(len(template_box), dtype=int)

    callback_class = MaxScoreOverRotations
    if args.peak_calling:
        callback_class = PeakCallerMaximumFilter

    splits, schedule = compute_parallelization_schedule(
        shape1=target.shape,
        shape2=template_box,
        shape1_padding=target_padding,
        max_cores=args.cores,
        max_ram=args.memory,
        split_only_outer=args.use_gpu,
        matching_method=args.score,
        analyzer_method=callback_class.__name__,
        backend=backend._backend_name,
        float_nbytes=backend.datatype_bytes(backend._default_dtype),
        complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
        integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
    )

    if splits is None:
        print(
            "Found no suitable parallelization schedule. Consider increasing"
            " available RAM or decreasing number of cores."
        )
        exit(-1)

    analyzer_args = {
        "score_threshold": args.score_threshold,
        "number_of_peaks": 1000,
        "convolution_mode": "valid",
        "use_memmap": args.use_memmap,
    }

    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
    matching_data = MatchingData(target=target, template=template.data)
    matching_data.rotations = get_rotation_matrices(
        angular_sampling=args.angular_sampling, dim=target.data.ndim
    )
    if args.angular_sampling >= 180:
        ndim = target.data.ndim
        matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)

    template_filter, target_filter = setup_filter(args, template, target)
    matching_data.template_filter = template_filter
    matching_data.target_filter = target_filter

    matching_data._invert_target = args.invert_target_contrast
    if target_mask is not None:
        matching_data.target_mask = target_mask
    if template_mask is not None:
        matching_data.template_mask = template_mask.data

    n_splits = np.prod(list(splits.values()))
    target_split = ", ".join(
        [":".join([str(x) for x in axis]) for axis in splits.items()]
    )
    # Only report GPUs when they are actually used. gpu_indices may still be
    # the raw --gpu_indices string when --use_gpu is not set.
    gpus_used = 0
    if args.use_gpu and args.gpu_indices is not None:
        gpus_used = len(args.gpu_indices)
    options = {
        "CPU Cores": args.cores,
        "Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
        "Use Mixed Precision": args.use_mixed_precision,
        "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
        "Temporary Directory": args.temp_directory,
        "Extend Fourier Grid": not args.no_fourier_padding,
        "Extend Target Edges": not args.no_edge_padding,
        "Interpolation Order": args.interpolation_order,
        "Score": f"{args.score}",
        "Setup Function": f"{get_func_fullname(matching_setup)}",
        "Scoring Function": f"{get_func_fullname(matching_score)}",
        "Angular Sampling": f"{args.angular_sampling}"
        f" [{matching_data.rotations.shape[0]} rotations]",
        "Scramble Template": args.scramble_phases,
        "Target Splits": f"{target_split} [N={n_splits}]",
    }

    print_block(
        name="Template Matching Options",
        data=options,
        label_width=max(len(key) for key in options.keys()) + 2,
    )

    options = {"Analyzer": callback_class, **analyzer_args}
    print_block(
        name="Score Analysis Options",
        data=options,
        label_width=max(len(key) for key in options.keys()) + 2,
    )
    print("\n" + "-" * 80)

    outer_jobs = f"{schedule[0]} job{'s' if schedule[0] > 1 else ''}"
    inner_jobs = f"{schedule[1]} core{'s' if schedule[1] > 1 else ''}"
    n_splits = f"{n_splits} split{'s' if n_splits > 1 else ''}"
    print(f"\nDistributing {n_splits} on {outer_jobs} each using {inner_jobs}.")

    start = time()
    print("Running Template Matching. This might take a while ...")
    candidates = scan_subsets(
        matching_data=matching_data,
        job_schedule=schedule,
        matching_score=matching_score,
        matching_setup=matching_setup,
        callback_class=callback_class,
        callback_class_args=analyzer_args,
        target_splits=splits,
        pad_target_edges=args.pad_target_edges,
        pad_fourier=args.pad_fourier,
        interpolation_order=args.interpolation_order,
    )

    candidates = list(candidates) if candidates is not None else []
    if callback_class == MaxScoreOverRotations:
        if target_mask is not None and args.score != "MCC":
            candidates[0] *= target_mask.data
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            # Replace the serialized rotation-matrix keys with Euler angles.
            candidates[3] = {
                x: euler_from_rotationmatrix(
                    np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
                        candidates[0].ndim, candidates[0].ndim
                    )
                )
                for i, x in candidates[3].items()
            }

    candidates.append((target.origin, template.origin, target.sampling_rate, args))
    write_pickle(data=candidates, filename=args.output)

    runtime = time() - start
    print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
850
+
851
if __name__ == "__main__":
    # Run the command line interface only when executed as a script.
    main()