pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/METADATA +21 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +193 -27
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  70. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1106 @@
1
+ #!python
2
+ """CLI for basic pyTME template matching functions.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import os
9
+ import argparse
10
+ import warnings
11
+ from sys import exit
12
+ from time import time
13
+ from typing import Tuple
14
+ from copy import deepcopy
15
+ from tempfile import gettempdir
16
+ from os.path import exists, abspath
17
+
18
+ import numpy as np
19
+
20
+ from tme.backends import backend as be
21
+ from tme import Density, __version__, Orientations
22
+ from tme.matching_utils import scramble_phases, write_pickle
23
+ from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
24
+ from tme.rotations import (
25
+ get_cone_rotations,
26
+ get_rotation_matrices,
27
+ euler_to_rotationmatrix,
28
+ )
29
+ from tme.matching_data import MatchingData
30
+ from tme.analyzer import (
31
+ MaxScoreOverRotations,
32
+ PeakCallerMaximumFilter,
33
+ MaxScoreOverRotationsConstrained,
34
+ )
35
+ from tme.filters import (
36
+ CTF,
37
+ Wedge,
38
+ Compose,
39
+ BandPass,
40
+ CTFReconstructed,
41
+ WedgeReconstructed,
42
+ ReconstructFromTilt,
43
+ LinearWhiteningFilter,
44
+ BandPassReconstructed,
45
+ )
46
+ from tme.cli import (
47
+ get_func_fullname,
48
+ print_block,
49
+ print_entry,
50
+ check_positive,
51
+ sanitize_name,
52
+ )
53
+
54
+
55
+ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
56
+ """
57
+ Loads a mask in CCP4/MRC format and assesses whether the sampling_rate
58
+ and shape matches its target.
59
+
60
+ Parameters
61
+ ----------
62
+ mask_target : Density
63
+ Object the mask should be applied to
64
+ mask_path : str
65
+ Path to the mask in CCP4/MRC format.
66
+ kwargs : dict, optional
67
+ Keyword arguments passed to :py:meth:`tme.density.Density.from_file`.
68
+ Raise
69
+ -----
70
+ ValueError
71
+ If shape or sampling rate do not match between mask_target and mask
72
+
73
+ Returns
74
+ -------
75
+ Density
76
+ A density instance if the mask was validated and loaded otherwise None
77
+ """
78
+ mask = mask_path
79
+ if mask is not None:
80
+ mask = Density.from_file(mask, **kwargs)
81
+ mask.origin = deepcopy(mask_target.origin)
82
+ if not np.allclose(mask.shape, mask_target.shape):
83
+ raise ValueError(
84
+ f"Expected shape of {mask_path} was {mask_target.shape},"
85
+ f" got f{mask.shape}"
86
+ )
87
+ if not np.allclose(
88
+ np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
89
+ ):
90
+ raise ValueError(
91
+ f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
92
+ f", got f{mask.sampling_rate}"
93
+ )
94
+ return mask
95
+
96
+
97
+ def parse_rotation_logic(args, ndim):
98
+ if args.particle_diameter is not None:
99
+ resolution = Density.from_file(args.target, use_memmap=True)
100
+ resolution = 360 * np.maximum(
101
+ np.max(2 * resolution.sampling_rate),
102
+ args.lowpass if args.lowpass is not None else 0,
103
+ )
104
+ args.angular_sampling = resolution / (3.14159265358979 * args.particle_diameter)
105
+
106
+ if args.angular_sampling is not None:
107
+ rotations = get_rotation_matrices(
108
+ angular_sampling=args.angular_sampling,
109
+ dim=ndim,
110
+ use_optimized_set=not args.no_use_optimized_set,
111
+ )
112
+ if args.angular_sampling >= 180:
113
+ rotations = np.eye(ndim).reshape(1, ndim, ndim)
114
+ return rotations
115
+
116
+ if args.axis_sampling is None:
117
+ args.axis_sampling = args.cone_sampling
118
+
119
+ rotations = get_cone_rotations(
120
+ cone_angle=args.cone_angle,
121
+ cone_sampling=args.cone_sampling,
122
+ axis_angle=args.axis_angle,
123
+ axis_sampling=args.axis_sampling,
124
+ n_symmetry=args.axis_symmetry,
125
+ axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
126
+ reference=[0, 0, -1 if args.invert_cone else 1],
127
+ )
128
+ return rotations
129
+
130
+
131
+ def compute_schedule(
132
+ args,
133
+ matching_data: MatchingData,
134
+ callback_class,
135
+ pad_edges: bool = False,
136
+ ):
137
+ # User requested target padding
138
+ if args.pad_edges is True:
139
+ pad_edges = True
140
+
141
+ splits, schedule = matching_data.computation_schedule(
142
+ matching_method=args.score,
143
+ analyzer_method=callback_class.__name__,
144
+ use_gpu=args.use_gpu,
145
+ pad_fourier=False,
146
+ pad_target_edges=pad_edges,
147
+ available_memory=args.memory,
148
+ max_cores=args.cores,
149
+ )
150
+
151
+ if splits is None:
152
+ print(
153
+ "Found no suitable parallelization schedule. Consider increasing"
154
+ " available RAM or decreasing number of cores."
155
+ )
156
+ exit(-1)
157
+
158
+ n_splits = np.prod(list(splits.values()))
159
+ if pad_edges is False and len(matching_data._target_dim) == 0 and n_splits > 1:
160
+ args.pad_edges = True
161
+ return compute_schedule(args, matching_data, callback_class, True)
162
+ return splits, schedule
163
+
164
+
165
+ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
166
+ template_filter, target_filter = [], []
167
+
168
+ if args.tilt_angles is None:
169
+ args.tilt_angles = args.ctf_file
170
+
171
+ wedge = None
172
+ if args.tilt_angles is not None:
173
+ try:
174
+ wedge = Wedge.from_file(args.tilt_angles)
175
+ wedge.weight_type = args.tilt_weighting
176
+ if args.tilt_weighting in ("angle", None):
177
+ wedge = WedgeReconstructed(
178
+ angles=wedge.angles,
179
+ weight_wedge=args.tilt_weighting == "angle",
180
+ )
181
+ except (FileNotFoundError, AttributeError):
182
+ tilt_start, tilt_stop = args.tilt_angles.split(",")
183
+ tilt_start, tilt_stop = abs(float(tilt_start)), abs(float(tilt_stop))
184
+ wedge = WedgeReconstructed(
185
+ angles=(tilt_start, tilt_stop),
186
+ create_continuous_wedge=True,
187
+ weight_wedge=False,
188
+ reconstruction_filter=args.reconstruction_filter,
189
+ )
190
+ wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
191
+
192
+ wedge_target = WedgeReconstructed(
193
+ angles=wedge.angles,
194
+ weight_wedge=False,
195
+ create_continuous_wedge=True,
196
+ opening_axis=wedge.opening_axis,
197
+ tilt_axis=wedge.tilt_axis,
198
+ )
199
+
200
+ wedge.sampling_rate = template.sampling_rate
201
+ wedge_target.sampling_rate = template.sampling_rate
202
+
203
+ target_filter.append(wedge_target)
204
+ template_filter.append(wedge)
205
+
206
+ if args.ctf_file is not None or args.defocus is not None:
207
+ try:
208
+ ctf = CTF.from_file(
209
+ args.ctf_file,
210
+ spherical_aberration=args.spherical_aberration,
211
+ amplitude_contrast=args.amplitude_contrast,
212
+ acceleration_voltage=args.acceleration_voltage * 1e3,
213
+ )
214
+ if (len(ctf.angles) == 0) and wedge is None:
215
+ raise ValueError(
216
+ "You requested to specify the CTF per tilt, but did not specify "
217
+ "tilt angles via --tilt-angles or --ctf-file. "
218
+ )
219
+ if len(ctf.angles) == 0:
220
+ ctf.angles = wedge.angles
221
+
222
+ n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
223
+ if (n_tilts_ctfs != n_tils_angles) and isinstance(wedge, Wedge):
224
+ raise ValueError(
225
+ f"CTF file contains {n_tilts_ctfs} tilt, but received "
226
+ f"{n_tils_angles} tilt angles. Expected one angle per tilt"
227
+ )
228
+
229
+ except (FileNotFoundError, AttributeError):
230
+ ctf = CTFReconstructed(
231
+ defocus_x=args.defocus,
232
+ phase_shift=args.phase_shift,
233
+ amplitude_contrast=args.amplitude_contrast,
234
+ spherical_aberration=args.spherical_aberration,
235
+ acceleration_voltage=args.acceleration_voltage * 1e3,
236
+ )
237
+ ctf.flip_phase = args.no_flip_phase
238
+ ctf.sampling_rate = template.sampling_rate
239
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
240
+ ctf.correct_defocus_gradient = args.correct_defocus_gradient
241
+ template_filter.append(ctf)
242
+
243
+ if args.lowpass or args.highpass is not None:
244
+ lowpass, highpass = args.lowpass, args.highpass
245
+ if args.pass_format == "voxel":
246
+ if lowpass is not None:
247
+ lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
248
+ if highpass is not None:
249
+ highpass = np.max(np.multiply(highpass, template.sampling_rate))
250
+ elif args.pass_format == "frequency":
251
+ if lowpass is not None:
252
+ lowpass = np.max(np.divide(template.sampling_rate, lowpass))
253
+ if highpass is not None:
254
+ highpass = np.max(np.divide(template.sampling_rate, highpass))
255
+
256
+ try:
257
+ if args.lowpass >= args.highpass:
258
+ warnings.warn("--lowpass should be smaller than --highpass.")
259
+ except Exception:
260
+ pass
261
+
262
+ bandpass = BandPassReconstructed(
263
+ use_gaussian=args.no_pass_smooth,
264
+ lowpass=lowpass,
265
+ highpass=highpass,
266
+ sampling_rate=template.sampling_rate,
267
+ )
268
+ template_filter.append(bandpass)
269
+ target_filter.append(bandpass)
270
+
271
+ if args.whiten_spectrum:
272
+ whitening_filter = LinearWhiteningFilter()
273
+ template_filter.append(whitening_filter)
274
+ target_filter.append(whitening_filter)
275
+
276
+ rec_filt = (Wedge, CTF)
277
+ needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
278
+ if needs_reconstruction > 0 and args.reconstruction_filter is None:
279
+ warnings.warn(
280
+ "Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
281
+ "to avoid artifacts from reconstruction using weighted backprojection."
282
+ )
283
+
284
+ template_filter = sorted(
285
+ template_filter, key=lambda x: type(x) in rec_filt, reverse=True
286
+ )
287
+ if needs_reconstruction > 0:
288
+ relevant_filters = [x for x in template_filter if type(x) in rec_filt]
289
+ if len(relevant_filters) == 0:
290
+ raise ValueError("Filters require ")
291
+
292
+ reconstruction_filter = ReconstructFromTilt(
293
+ reconstruction_filter=args.reconstruction_filter,
294
+ interpolation_order=args.reconstruction_interpolation_order,
295
+ angles=relevant_filters[0].angles,
296
+ opening_axis=args.wedge_axes[0],
297
+ tilt_axis=args.wedge_axes[1],
298
+ )
299
+ template_filter.insert(needs_reconstruction, reconstruction_filter)
300
+
301
+ template_filter = Compose(template_filter) if len(template_filter) else None
302
+ target_filter = Compose(target_filter) if len(target_filter) else None
303
+ if args.no_filter_target:
304
+ target_filter = None
305
+
306
+ return template_filter, target_filter
307
+
308
+
309
+ def _format_sampling(arr, decimals: int = 2):
310
+ return tuple(round(float(x), decimals) for x in arr)
311
+
312
+
313
+ def parse_args():
314
+ parser = argparse.ArgumentParser(
315
+ description="Perform template matching.",
316
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
317
+ )
318
+
319
+ io_group = parser.add_argument_group("Input / Output")
320
+ io_group.add_argument(
321
+ "-m",
322
+ "--target",
323
+ type=str,
324
+ required=True,
325
+ help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
326
+ "tme.density.Density.from_file "
327
+ "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
328
+ )
329
+ io_group.add_argument(
330
+ "-M",
331
+ "--target-mask",
332
+ type=str,
333
+ required=False,
334
+ help="Path to a mask for the target in a supported format (see target).",
335
+ )
336
+ io_group.add_argument(
337
+ "-i",
338
+ "--template",
339
+ type=str,
340
+ required=True,
341
+ help="Path to a template in PDB/MMCIF or other supported formats (see target).",
342
+ )
343
+ io_group.add_argument(
344
+ "-I",
345
+ "--template-mask",
346
+ type=str,
347
+ required=False,
348
+ help="Path to a mask for the template in a supported format (see target).",
349
+ )
350
+ io_group.add_argument(
351
+ "-o",
352
+ "--output",
353
+ type=str,
354
+ required=False,
355
+ default="output.pickle",
356
+ help="Path to the output pickle file.",
357
+ )
358
+ io_group.add_argument(
359
+ "--invert-target-contrast",
360
+ action="store_true",
361
+ default=False,
362
+ help="Invert the target's contrast for cases where templates to-be-matched have "
363
+ "negative values, e.g. tomograms.",
364
+ )
365
+ io_group.add_argument(
366
+ "--scramble-phases",
367
+ action="store_true",
368
+ default=False,
369
+ help="Phase scramble the template to generate a noise score background.",
370
+ )
371
+
372
+ sampling_group = parser.add_argument_group("Sampling")
373
+ sampling_group.add_argument(
374
+ "--orientations",
375
+ default=None,
376
+ required=False,
377
+ help="Path to a file readable by Orientations.from_file containing "
378
+ "translations and rotations of candidate peaks to refine.",
379
+ )
380
+ sampling_group.add_argument(
381
+ "--orientations-scaling",
382
+ required=False,
383
+ type=float,
384
+ default=1.0,
385
+ help="Scaling factor to map candidate translations onto the target. "
386
+ "Assuming coordinates are in Å and target sampling rate are 3Å/voxel, "
387
+ "the corresponding --orientations-scaling would be 3.",
388
+ )
389
+ sampling_group.add_argument(
390
+ "--orientations-cone",
391
+ required=False,
392
+ type=float,
393
+ default=20.0,
394
+ help="Accept orientations within specified cone angle of each orientation.",
395
+ )
396
+ sampling_group.add_argument(
397
+ "--orientations-uncertainty",
398
+ required=False,
399
+ type=str,
400
+ default="10",
401
+ help="Accept translations within the specified radius of each orientation. "
402
+ "Can be a single value or comma-separated string for per-axis uncertainty.",
403
+ )
404
+
405
+ scoring_group = parser.add_argument_group("Scoring")
406
+ scoring_group.add_argument(
407
+ "-s",
408
+ "--score",
409
+ type=str,
410
+ default="FLCSphericalMask",
411
+ choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
412
+ help="Template matching scoring function.",
413
+ )
414
+
415
+ angular_group = parser.add_argument_group("Angular Sampling")
416
+ angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
417
+
418
+ angular_exclusive.add_argument(
419
+ "-a",
420
+ "--angular-sampling",
421
+ type=check_positive,
422
+ default=None,
423
+ help="Angular sampling rate. Lower values = more rotations, higher precision.",
424
+ )
425
+ angular_exclusive.add_argument(
426
+ "--cone-angle",
427
+ type=check_positive,
428
+ default=None,
429
+ help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
430
+ "narrow interval around a known orientation, e.g. for surface oversampling.",
431
+ )
432
+ angular_exclusive.add_argument(
433
+ "--particle-diameter",
434
+ type=check_positive,
435
+ default=None,
436
+ help="Particle diameter in units of sampling rate.",
437
+ )
438
+ angular_group.add_argument(
439
+ "--cone-axis",
440
+ type=check_positive,
441
+ default=2,
442
+ help="Principal axis to build cone around.",
443
+ )
444
+ angular_group.add_argument(
445
+ "--invert-cone",
446
+ action="store_true",
447
+ help="Invert cone handedness.",
448
+ )
449
+ angular_group.add_argument(
450
+ "--cone-sampling",
451
+ type=check_positive,
452
+ default=None,
453
+ help="Sampling rate of the cone in degrees.",
454
+ )
455
+ angular_group.add_argument(
456
+ "--axis-angle",
457
+ type=check_positive,
458
+ default=360.0,
459
+ required=False,
460
+ help="Sampling angle along the z-axis of the cone.",
461
+ )
462
+ angular_group.add_argument(
463
+ "--axis-sampling",
464
+ type=check_positive,
465
+ default=None,
466
+ required=False,
467
+ help="Sampling rate along the z-axis of the cone. Defaults to --cone-sampling.",
468
+ )
469
+ angular_group.add_argument(
470
+ "--axis-symmetry",
471
+ type=check_positive,
472
+ default=1,
473
+ required=False,
474
+ help="N-fold symmetry around z-axis of the cone.",
475
+ )
476
+ angular_group.add_argument(
477
+ "--no-use-optimized-set",
478
+ action="store_true",
479
+ default=False,
480
+ required=False,
481
+ help="Whether to use random uniform instead of optimized rotation sets.",
482
+ )
483
+
484
+ computation_group = parser.add_argument_group("Computation")
485
+ computation_group.add_argument(
486
+ "-n",
487
+ dest="cores",
488
+ required=False,
489
+ type=int,
490
+ default=4,
491
+ help="Number of cores used for template matching.",
492
+ )
493
+ computation_group.add_argument(
494
+ "--gpu-indices",
495
+ type=str,
496
+ default=None,
497
+ help="Comma-separated GPU indices (e.g., '0,1,2' for first 3 GPUs). Otherwise "
498
+ "CUDA_VISIBLE_DEVICES will be used.",
499
+ )
500
+ computation_group.add_argument(
501
+ "--memory",
502
+ required=False,
503
+ type=int,
504
+ default=None,
505
+ help="Amount of memory that can be used in bytes.",
506
+ )
507
+ computation_group.add_argument(
508
+ "--memory-scaling",
509
+ required=False,
510
+ type=float,
511
+ default=0.85,
512
+ help="Fraction of available memory to be used. Ignored if --memory is set.",
513
+ )
514
+ computation_group.add_argument(
515
+ "--temp-directory",
516
+ default=gettempdir(),
517
+ help="Temporary directory for memmaps. Better I/O improves runtime.",
518
+ )
519
+ computation_group.add_argument(
520
+ "--backend",
521
+ default=be._backend_name,
522
+ choices=be.available_backends(),
523
+ help="Set computation backend.",
524
+ )
525
+ filter_group = parser.add_argument_group("Filters")
526
+ filter_group.add_argument(
527
+ "--lowpass",
528
+ type=float,
529
+ required=False,
530
+ help="Resolution to lowpass filter template and target to in the same unit "
531
+ "as the sampling rate of template and target (typically Ångstrom).",
532
+ )
533
+ filter_group.add_argument(
534
+ "--highpass",
535
+ type=float,
536
+ required=False,
537
+ help="Resolution to highpass filter template and target to in the same unit "
538
+ "as the sampling rate of template and target (typically Ångstrom).",
539
+ )
540
+ filter_group.add_argument(
541
+ "--no-pass-smooth",
542
+ action="store_false",
543
+ default=True,
544
+ help="Whether a hard edge filter should be used for --lowpass and --highpass.",
545
+ )
546
+ filter_group.add_argument(
547
+ "--pass-format",
548
+ type=str,
549
+ required=False,
550
+ default="sampling_rate",
551
+ choices=["sampling_rate", "voxel", "frequency"],
552
+ help="How values passed to --lowpass and --highpass should be interpreted. "
553
+ "Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
554
+ )
555
+ filter_group.add_argument(
556
+ "--whiten-spectrum",
557
+ action="store_true",
558
+ default=None,
559
+ help="Apply spectral whitening to template and target based on target spectrum.",
560
+ )
561
+ filter_group.add_argument(
562
+ "--wedge-axes",
563
+ type=str,
564
+ required=False,
565
+ default="2,0",
566
+ help="Indices of projection (wedge opening) and tilt axis, e.g., '2,0' "
567
+ "for the typical projection over z and tilting over the x-axis.",
568
+ )
569
+ filter_group.add_argument(
570
+ "--tilt-angles",
571
+ type=str,
572
+ required=False,
573
+ default=None,
574
+ help="Path to a file specifying tilt angles. This can be a Warp/M XML file, "
575
+ "a tomostar STAR file, an MMOD file, a tab-separated file with column name "
576
+ "'angles', or a single column file without header. Exposure will be taken from "
577
+ "the input file; if you are using a tab-separated file, the column names "
578
+ "'angles' and 'weights' need to be present. It is also possible to specify a "
579
+ "continuous wedge mask using e.g., -50,45.",
580
+ )
581
+ filter_group.add_argument(
582
+ "--tilt-weighting",
583
+ type=str,
584
+ required=False,
585
+ choices=["angle", "relion", "grigorieff"],
586
+ default=None,
587
+ help="Weighting scheme used to reweight individual tilts. Available options: "
588
+ "angle (cosine based weighting), "
589
+ "relion (relion formalism for wedge weighting) requires,"
590
+ "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
591
+ "relion and grigorieff require electron doses in --tilt-angles weights column.",
592
+ )
593
+ filter_group.add_argument(
594
+ "--reconstruction-filter",
595
+ type=str,
596
+ required=False,
597
+ choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
598
+ default="ramp",
599
+ help="Filter applied when reconstructing (N+1)-D from N-D filters.",
600
+ )
601
+ filter_group.add_argument(
602
+ "--reconstruction-interpolation-order",
603
+ type=int,
604
+ default=1,
605
+ required=False,
606
+ help="Analogous to --interpolation-order but for reconstruction.",
607
+ )
608
+ filter_group.add_argument(
609
+ "--no-filter-target",
610
+ action="store_true",
611
+ default=False,
612
+ help="Whether to not apply potential filters to the target.",
613
+ )
614
+
615
+ ctf_group = parser.add_argument_group("Contrast Transfer Function")
616
+ ctf_group.add_argument(
617
+ "--ctf-file",
618
+ type=str,
619
+ required=False,
620
+ default=None,
621
+ help="Path to a file with CTF parameters. This can be a Warp/M XML file "
622
+ "a GCTF/Relion STAR file, an MDOC file, or the output of CTFFIND4. If the file "
623
+ " does not specify tilt angles, --tilt-angles are used.",
624
+ )
625
+ ctf_group.add_argument(
626
+ "--defocus",
627
+ type=float,
628
+ required=False,
629
+ default=None,
630
+ help="Defocus in units of sampling rate (typically Ångstrom), e.g., 30000 "
631
+ "for a defocus of 3 micrometer. Superseded by --ctf-file.",
632
+ )
633
+ ctf_group.add_argument(
634
+ "--phase-shift",
635
+ type=float,
636
+ required=False,
637
+ default=0,
638
+ help="Phase shift in degrees. Superseded by --ctf-file.",
639
+ )
640
+ ctf_group.add_argument(
641
+ "--acceleration-voltage",
642
+ type=float,
643
+ required=False,
644
+ default=300,
645
+ help="Acceleration voltage in kV.",
646
+ )
647
+ ctf_group.add_argument(
648
+ "--spherical-aberration",
649
+ type=float,
650
+ required=False,
651
+ default=2.7e7,
652
+ help="Spherical aberration in units of sampling rate (typically Ångstrom).",
653
+ )
654
+ ctf_group.add_argument(
655
+ "--amplitude-contrast",
656
+ type=float,
657
+ required=False,
658
+ default=0.07,
659
+ help="Amplitude contrast.",
660
+ )
661
+ ctf_group.add_argument(
662
+ "--no-flip-phase",
663
+ action="store_false",
664
+ required=False,
665
+ help="Do not perform phase-flipping CTF correction.",
666
+ )
667
+ ctf_group.add_argument(
668
+ "--correct-defocus-gradient",
669
+ action="store_true",
670
+ required=False,
671
+ help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
672
+ "defocus gradients.",
673
+ )
674
+
675
+ performance_group = parser.add_argument_group("Performance")
676
+ performance_group.add_argument(
677
+ "--centering",
678
+ action="store_true",
679
+ help="Center the template in the box if it has not been done already.",
680
+ )
681
+ performance_group.add_argument(
682
+ "--pad-edges",
683
+ action="store_true",
684
+ default=False,
685
+ help="Useful if the target does not have a well-defined bounding box. Will be "
686
+ "activated automatically if splitting is required to avoid boundary artifacts.",
687
+ )
688
+ performance_group.add_argument(
689
+ "--pad-filter",
690
+ action="store_true",
691
+ default=False,
692
+ help="Pad the template filter to the shape of the target. Useful for fast "
693
+ "oscillating filters to avoid aliasing effects.",
694
+ )
695
+ performance_group.add_argument(
696
+ "--interpolation-order",
697
+ required=False,
698
+ type=int,
699
+ default=None,
700
+ help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
701
+ "and pytorch backends.",
702
+ )
703
+ performance_group.add_argument(
704
+ "--use-mixed-precision",
705
+ action="store_true",
706
+ default=False,
707
+ help="Use float16 for real values operations where possible. Not supported "
708
+ "for jax backend.",
709
+ )
710
+ performance_group.add_argument(
711
+ "--use-memmap",
712
+ action="store_true",
713
+ default=False,
714
+ help="Memmap large data to disk, e.g., matching on unbinned tomograms.",
715
+ )
716
+
717
+ analyzer_group = parser.add_argument_group("Output / Analysis")
718
+ analyzer_group.add_argument(
719
+ "--score-threshold",
720
+ required=False,
721
+ type=float,
722
+ default=0,
723
+ help="Minimum template matching scores to consider for analysis.",
724
+ )
725
+ analyzer_group.add_argument(
726
+ "-p",
727
+ "--peak-calling",
728
+ action="store_true",
729
+ default=False,
730
+ help="Perform peak calling instead of score aggregation.",
731
+ )
732
+ analyzer_group.add_argument(
733
+ "--num-peaks",
734
+ type=int,
735
+ default=1000,
736
+ help="Number of peaks to call, 1000 by default.",
737
+ )
738
+ args = parser.parse_args()
739
+ args.version = __version__
740
+
741
+ if args.interpolation_order is None:
742
+ args.interpolation_order = 3
743
+ if args.backend in ("jax", "pytorch"):
744
+ args.interpolation_order = 1
745
+
746
+ if args.interpolation_order < 0:
747
+ args.interpolation_order = None
748
+
749
+ if args.temp_directory is None:
750
+ args.temp_directory = gettempdir()
751
+
752
+ os.environ["TMPDIR"] = args.temp_directory
753
+ if args.gpu_indices is not None:
754
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
755
+
756
+ if args.tilt_angles is not None and not exists(args.tilt_angles):
757
+ try:
758
+ float(args.tilt_angles.split(",")[0])
759
+ except Exception:
760
+ raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
761
+
762
+ if args.ctf_file is not None and args.tilt_angles is None:
763
+ # Check if tilt angles can be extracted from CTF specification
764
+ try:
765
+ ctf = CTF.from_file(args.ctf_file)
766
+ if ctf.angles is None:
767
+ raise ValueError
768
+ args.tilt_angles = args.ctf_file
769
+ except Exception:
770
+ raise ValueError(
771
+ "Need to specify --tilt-angles when not provided in --ctf-file."
772
+ )
773
+
774
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
775
+ if args.orientations is not None:
776
+ orientations = Orientations.from_file(args.orientations)
777
+ orientations.translations = np.divide(
778
+ orientations.translations, args.orientations_scaling
779
+ )
780
+ args.orientations = orientations
781
+
782
+ args.target = abspath(args.target)
783
+ if args.target_mask is not None:
784
+ args.target_mask = abspath(args.target_mask)
785
+
786
+ args.template = abspath(args.template)
787
+ if args.template_mask is not None:
788
+ args.template_mask = abspath(args.template_mask)
789
+
790
+ return args
791
+
792
+
793
def main():
    """Run exhaustive template matching from the command line.

    Loads target/template densities (plus optional masks), prints a summary
    of inputs and options, configures the compute backend, executes
    ``scan_subsets`` and pickles the resulting candidates to ``args.output``.

    Side effects: reads files, may set ``CUDA_VISIBLE_DEVICES``, writes the
    output pickle and prints progress to stdout.
    """
    args = parse_args()
    print_entry()

    target = Density.from_file(args.target, use_memmap=True)

    # EAFP: try reading the template as a density map first; fall back to
    # building a density from an atomic structure at the target's sampling.
    try:
        template = Density.from_file(args.template)
    except Exception:
        template = Density.from_structure(
            filename_or_structure=args.template,
            sampling_rate=target.sampling_rate,
        )

    # Resample the template onto the target grid when the sampling rates
    # differ beyond two decimal places.
    if target.sampling_rate.size == template.sampling_rate.size:
        if not np.allclose(
            np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
        ):
            print(
                f"Resampling template to {target.sampling_rate}. "
                "Consider providing a template with the same sampling rate as the target."
            )
            template = template.resample(target.sampling_rate, order=3)

    template_mask = load_and_validate_mask(
        mask_target=template, mask_path=args.template_mask
    )
    target_mask = load_and_validate_mask(
        mask_target=target, mask_path=args.target_mask, use_memmap=True
    )

    initial_shape = target.shape
    print_block(
        name="Target",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": _format_sampling(target.sampling_rate),
            "Final Shape": target.shape,
        },
    )

    if target_mask:
        print_block(
            name="Target Mask",
            data={
                "Initial Shape": initial_shape,
                "Sampling Rate": _format_sampling(target_mask.sampling_rate),
                "Final Shape": target_mask.shape,
            },
        )

    initial_shape = template.shape
    translation = np.zeros(len(template.shape), dtype=np.float32)
    if args.centering:
        # centered(0) returns the recentered density and the applied shift.
        template, translation = template.centered(0)

    print_block(
        name="Template",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": _format_sampling(template.sampling_rate),
            "Final Shape": template.shape,
        },
    )

    if template_mask is None:
        # No mask supplied: build an all-ones mask covering the template.
        template_mask = template.empty
        if not args.centering:
            enclosing_box = template.minimum_enclosing_box(
                0, use_geometric_center=False
            )
            template_mask.adjust_box(enclosing_box)

        template_mask.data[:] = 1
        translation = np.zeros_like(translation)

    # Align the mask with the (possibly recentered) template: pad to shape,
    # then undo the accumulated voxel translation via a rigid transform.
    template_mask.pad(template.shape, center=False)
    origin_translation = np.divide(
        np.subtract(template.origin, template_mask.origin), template.sampling_rate
    )
    translation = np.add(translation, origin_translation)

    template_mask = template_mask.rigid_transform(
        rotation_matrix=np.eye(template_mask.data.ndim),
        translation=-translation,
        order=1,
    )
    template_mask.origin = template.origin.copy()
    print_block(
        name="Template Mask",
        data={
            # Fixed typo: was "Inital Shape", inconsistent with the other
            # summary blocks above.
            "Initial Shape": initial_shape,
            "Sampling Rate": _format_sampling(template_mask.sampling_rate),
            "Final Shape": template_mask.shape,
        },
    )
    print("\n" + "-" * 80)

    if args.scramble_phases:
        # Phase scrambling yields a noise template for background estimation.
        template.data = scramble_phases(
            template.data, noise_proportion=1.0, normalize_power=False
        )

    # Select the analyzer: peak calling and orientation-constrained matching
    # override the default max-score-over-rotations accumulator.
    callback_class = MaxScoreOverRotations
    if args.peak_calling:
        callback_class = PeakCallerMaximumFilter

    if args.orientations is not None:
        callback_class = MaxScoreOverRotationsConstrained

    # Determine suitable backend for the selected operation
    available_backends = be.available_backends()
    if args.backend not in available_backends:
        raise ValueError("Requested backend is not available.")
    if args.backend == "jax" and callback_class != MaxScoreOverRotations:
        raise ValueError(
            "Jax backend only supports the MaxScoreOverRotations analyzer."
        )

    if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
        warnings.warn(
            "Jax and pytorch do not support interpolation order 3, setting it to 1."
        )
        args.interpolation_order = 1

    if args.backend in ("pytorch", "cupy", "jax"):
        gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if gpu_devices is None:
            warnings.warn(
                "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
                "Assuming device 0.",
            )
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        # One outer job per visible GPU.
        args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
        args.gpu_indices = [
            int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        ]

    # Finally set the desired backend
    device = "cuda"
    args.use_gpu = False
    be.change_backend(args.backend)
    if args.backend in ("jax", "pytorch", "cupy"):
        args.use_gpu = True

    if args.backend == "pytorch":
        try:
            be.change_backend("pytorch", device=device)
            # Trigger exception if not compiled with device
            be.get_available_memory()
        except Exception as e:
            print(e)
            device = "cpu"
            args.use_gpu = False
            be.change_backend("pytorch", device=device)

    # TODO: Make the inverse casting from complex64 -> float 16 stable
    # if args.use_mixed_precision:
    #     be.change_backend(
    #         backend_name=args.backend,
    #         float_dtype=be._array_backend.float16,
    #         complex_dtype=be._array_backend.complex64,
    #         int_dtype=be._array_backend.int16,
    #         device=device,
    #     )

    available_memory = be.get_available_memory() * be.device_count()
    if args.memory is None:
        args.memory = int(args.memory_scaling * available_memory)

    if args.orientations_uncertainty is not None:
        args.orientations_uncertainty = tuple(
            int(x) for x in args.orientations_uncertainty.split(",")
        )

    matching_data = MatchingData(
        target=target,
        template=template.data,
        target_mask=target_mask,
        template_mask=template_mask,
        invert_target=args.invert_target_contrast,
        rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
    )

    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
    matching_data.template_filter, matching_data.target_filter = setup_filter(
        args, template, target
    )

    matching_data.set_matching_dimension(
        target_dim=target.metadata.get("batch_dimension", None),
        template_dim=template.metadata.get("batch_dimension", None),
    )
    splits, schedule = compute_schedule(args, matching_data, callback_class)

    n_splits = np.prod(list(splits.values()))
    target_split = ", ".join(
        [":".join([str(x) for x in axis]) for axis in splits.items()]
    )
    # gpu_indices is only assigned in the GPU-backend branch above; guard the
    # lookup so CPU-only runs do not raise AttributeError.
    gpu_indices = getattr(args, "gpu_indices", None)
    gpus_used = 0 if gpu_indices is None else len(gpu_indices)
    options = {
        "Angular Sampling": f"{args.angular_sampling}"
        f" [{matching_data.rotations.shape[0]} rotations]",
        "Center Template": args.centering,
        "Scramble Template": args.scramble_phases,
        "Invert Contrast": args.invert_target_contrast,
        "Extend Target Edges": args.pad_edges,
        "Interpolation Order": args.interpolation_order,
        "Setup Function": f"{get_func_fullname(matching_setup)}",
        "Scoring Function": f"{get_func_fullname(matching_score)}",
    }

    print_block(
        name="Template Matching",
        data=options,
        label_width=max(len(key) for key in options.keys()) + 3,
    )

    compute_options = {
        "Backend": be._BACKEND_REGISTRY[be._backend_name],
        "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
        "Use Mixed Precision": args.use_mixed_precision,
        "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
        "Temporary Directory": args.temp_directory,
        "Target Splits": f"{target_split} [N={n_splits}]",
    }
    # NOTE(review): label_width is computed from `options`, not the dict being
    # printed — presumably for uniform alignment across blocks; confirm.
    print_block(
        name="Computation",
        data=compute_options,
        label_width=max(len(key) for key in options.keys()) + 3,
    )

    filter_args = {
        "Lowpass": args.lowpass,
        "Highpass": args.highpass,
        "Smooth Pass": args.no_pass_smooth,
        "Pass Format": args.pass_format,
        "Spectral Whitening": args.whiten_spectrum,
        "Wedge Axes": args.wedge_axes,
        "Tilt Angles": args.tilt_angles,
        "Tilt Weighting": args.tilt_weighting,
        "Reconstruction Filter": args.reconstruction_filter,
        "Extend Filter Grid": args.pad_filter,
    }
    if args.ctf_file is not None or args.defocus is not None:
        filter_args["CTF File"] = args.ctf_file
        filter_args["Flip Phase"] = args.no_flip_phase
        filter_args["Correct Defocus"] = args.correct_defocus_gradient

    # Only report filters that were actually requested.
    filter_args = {k: v for k, v in filter_args.items() if v is not None}
    if len(filter_args):
        print_block(
            name="Filters",
            data=filter_args,
            label_width=max(len(key) for key in options.keys()) + 3,
        )

    analyzer_args = {
        "score_threshold": args.score_threshold,
        "num_peaks": args.num_peaks,
        "min_distance": max(template.shape) // 3,
        "use_memmap": args.use_memmap,
    }
    if args.orientations is not None:
        analyzer_args["reference"] = (0, 0, 1)
        analyzer_args["cone_angle"] = args.orientations_cone
        analyzer_args["acceptance_radius"] = args.orientations_uncertainty
        analyzer_args["positions"] = args.orientations.translations
        analyzer_args["rotations"] = euler_to_rotationmatrix(
            args.orientations.rotations
        )

    print_block(
        name="Analyzer",
        data={
            "Analyzer": callback_class,
            **{sanitize_name(k): v for k, v in analyzer_args.items()},
        },
        label_width=max(len(key) for key in options.keys()) + 3,
    )
    print("\n" + "-" * 80)

    outer_jobs = f"{schedule[0]} job{'s' if schedule[0] > 1 else ''}"
    inner_jobs = f"{schedule[1]} core{'s' if schedule[1] > 1 else ''}"
    n_splits = f"{n_splits} split{'s' if n_splits > 1 else ''}"
    print(f"\nDistributing {n_splits} on {outer_jobs} each using {inner_jobs}.")

    start = time()
    print("Running Template Matching. This might take a while ...")
    candidates = scan_subsets(
        matching_data=matching_data,
        job_schedule=schedule,
        matching_score=matching_score,
        matching_setup=matching_setup,
        callback_class=callback_class,
        callback_class_args=analyzer_args,
        target_splits=splits,
        pad_target_edges=args.pad_edges,
        pad_template_filter=args.pad_filter,
        interpolation_order=args.interpolation_order,
    )

    # Append provenance (origins, sampling rate, CLI args) so postprocessing
    # can reconstruct coordinates from the pickled results.
    candidates = list(candidates) if candidates is not None else []
    candidates.append((target.origin, template.origin, template.sampling_rate, args))
    write_pickle(data=candidates, filename=args.output)

    runtime = time() - start
    print("\n" + "-" * 80)
    print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
1103
+
1104
+
1105
# Script entry point: run template matching only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()