pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
@@ -12,8 +12,8 @@ from sys import exit
12
12
  from time import time
13
13
  from typing import Tuple
14
14
  from copy import deepcopy
15
- from os.path import exists
16
15
  from tempfile import gettempdir
16
+ from os.path import exists, abspath
17
17
 
18
18
  import numpy as np
19
19
 
@@ -36,13 +36,20 @@ from tme.filters import (
36
36
  CTF,
37
37
  Wedge,
38
38
  Compose,
39
- BandPassFilter,
39
+ BandPass,
40
40
  CTFReconstructed,
41
41
  WedgeReconstructed,
42
42
  ReconstructFromTilt,
43
43
  LinearWhiteningFilter,
44
+ BandPassReconstructed,
45
+ )
46
+ from tme.cli import (
47
+ get_func_fullname,
48
+ print_block,
49
+ print_entry,
50
+ check_positive,
51
+ sanitize_name,
44
52
  )
45
- from tme.cli import get_func_fullname, print_block, print_entry, check_positive
46
53
 
47
54
 
48
55
  def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
@@ -116,7 +123,7 @@ def parse_rotation_logic(args, ndim):
116
123
  axis_sampling=args.axis_sampling,
117
124
  n_symmetry=args.axis_symmetry,
118
125
  axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
119
- reference=[0, 0, -1],
126
+ reference=[0, 0, -1 if args.invert_cone else 1],
120
127
  )
121
128
  return rotations
122
129
 
@@ -158,6 +165,9 @@ def compute_schedule(
158
165
  def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
159
166
  template_filter, target_filter = [], []
160
167
 
168
+ if args.tilt_angles is None:
169
+ args.tilt_angles = args.ctf_file
170
+
161
171
  wedge = None
162
172
  if args.tilt_angles is not None:
163
173
  try:
@@ -177,28 +187,34 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
177
187
  weight_wedge=False,
178
188
  reconstruction_filter=args.reconstruction_filter,
179
189
  )
190
+ wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
180
191
 
181
192
  wedge_target = WedgeReconstructed(
182
193
  angles=wedge.angles,
183
194
  weight_wedge=False,
184
195
  create_continuous_wedge=True,
185
- opening_axis=args.wedge_axes[0],
186
- tilt_axis=args.wedge_axes[1],
196
+ opening_axis=wedge.opening_axis,
197
+ tilt_axis=wedge.tilt_axis,
187
198
  )
188
- wedge.opening_axis = args.wedge_axes[0]
189
- wedge.tilt_axis = args.wedge_axes[1]
199
+
200
+ wedge.sampling_rate = template.sampling_rate
201
+ wedge_target.sampling_rate = template.sampling_rate
190
202
 
191
203
  target_filter.append(wedge_target)
192
204
  template_filter.append(wedge)
193
205
 
194
- args.ctf_file is not None
195
206
  if args.ctf_file is not None or args.defocus is not None:
196
207
  try:
197
- ctf = CTF.from_file(args.ctf_file)
208
+ ctf = CTF.from_file(
209
+ args.ctf_file,
210
+ spherical_aberration=args.spherical_aberration,
211
+ amplitude_contrast=args.amplitude_contrast,
212
+ acceleration_voltage=args.acceleration_voltage * 1e3,
213
+ )
198
214
  if (len(ctf.angles) == 0) and wedge is None:
199
215
  raise ValueError(
200
216
  "You requested to specify the CTF per tilt, but did not specify "
201
- "tilt angles via --tilt_angles or --ctf_file (Warp/M XML format). "
217
+ "tilt angles via --tilt-angles or --ctf-file. "
202
218
  )
203
219
  if len(ctf.angles) == 0:
204
220
  ctf.angles = wedge.angles
@@ -206,20 +222,21 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
206
222
  n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
207
223
  if (n_tilts_ctfs != n_tils_angles) and isinstance(wedge, Wedge):
208
224
  raise ValueError(
209
- f"CTF file contains {n_tilts_ctfs} tilt, but match_template "
210
- f"recieved {n_tils_angles} tilt angles. Expected one angle "
211
- "per tilt."
225
+ f"CTF file contains {n_tilts_ctfs} tilt, but recieved "
226
+ f"{n_tils_angles} tilt angles. Expected one angle per tilt"
212
227
  )
213
228
 
214
229
  except (FileNotFoundError, AttributeError):
215
- ctf = CTFReconstructed(defocus_x=args.defocus, phase_shift=args.phase_shift)
216
-
217
- ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
218
- ctf.sampling_rate = template.sampling_rate
230
+ ctf = CTFReconstructed(
231
+ defocus_x=args.defocus,
232
+ phase_shift=args.phase_shift,
233
+ amplitude_contrast=args.amplitude_contrast,
234
+ spherical_aberration=args.spherical_aberration,
235
+ acceleration_voltage=args.acceleration_voltage * 1e3,
236
+ )
219
237
  ctf.flip_phase = args.no_flip_phase
220
- ctf.amplitude_contrast = args.amplitude_contrast
221
- ctf.spherical_aberration = args.spherical_aberration
222
- ctf.acceleration_voltage = args.acceleration_voltage * 1e3
238
+ ctf.sampling_rate = template.sampling_rate
239
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
223
240
  ctf.correct_defocus_gradient = args.correct_defocus_gradient
224
241
  template_filter.append(ctf)
225
242
 
@@ -242,7 +259,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
242
259
  except Exception:
243
260
  pass
244
261
 
245
- bandpass = BandPassFilter(
262
+ bandpass = BandPassReconstructed(
246
263
  use_gaussian=args.no_pass_smooth,
247
264
  lowpass=lowpass,
248
265
  highpass=highpass,
@@ -303,7 +320,6 @@ def parse_args():
303
320
  io_group.add_argument(
304
321
  "-m",
305
322
  "--target",
306
- dest="target",
307
323
  type=str,
308
324
  required=True,
309
325
  help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
@@ -311,8 +327,8 @@ def parse_args():
311
327
  "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
312
328
  )
313
329
  io_group.add_argument(
314
- "--target_mask",
315
- dest="target_mask",
330
+ "-M",
331
+ "--target-mask",
316
332
  type=str,
317
333
  required=False,
318
334
  help="Path to a mask for the target in a supported format (see target).",
@@ -320,14 +336,13 @@ def parse_args():
320
336
  io_group.add_argument(
321
337
  "-i",
322
338
  "--template",
323
- dest="template",
324
339
  type=str,
325
340
  required=True,
326
341
  help="Path to a template in PDB/MMCIF or other supported formats (see target).",
327
342
  )
328
343
  io_group.add_argument(
329
- "--template_mask",
330
- dest="template_mask",
344
+ "-I",
345
+ "--template-mask",
331
346
  type=str,
332
347
  required=False,
333
348
  help="Path to a mask for the template in a supported format (see target).",
@@ -335,23 +350,20 @@ def parse_args():
335
350
  io_group.add_argument(
336
351
  "-o",
337
352
  "--output",
338
- dest="output",
339
353
  type=str,
340
354
  required=False,
341
355
  default="output.pickle",
342
356
  help="Path to the output pickle file.",
343
357
  )
344
358
  io_group.add_argument(
345
- "--invert_target_contrast",
346
- dest="invert_target_contrast",
359
+ "--invert-target-contrast",
347
360
  action="store_true",
348
361
  default=False,
349
362
  help="Invert the target's contrast for cases where templates to-be-matched have "
350
363
  "negative values, e.g. tomograms.",
351
364
  )
352
365
  io_group.add_argument(
353
- "--scramble_phases",
354
- dest="scramble_phases",
366
+ "--scramble-phases",
355
367
  action="store_true",
356
368
  default=False,
357
369
  help="Phase scramble the template to generate a noise score background.",
@@ -360,30 +372,29 @@ def parse_args():
360
372
  sampling_group = parser.add_argument_group("Sampling")
361
373
  sampling_group.add_argument(
362
374
  "--orientations",
363
- dest="orientations",
364
375
  default=None,
365
376
  required=False,
366
- help="Path to a file readable via Orientations.from_file containing "
377
+ help="Path to a file readable by Orientations.from_file containing "
367
378
  "translations and rotations of candidate peaks to refine.",
368
379
  )
369
380
  sampling_group.add_argument(
370
- "--orientations_scaling",
381
+ "--orientations-scaling",
371
382
  required=False,
372
383
  type=float,
373
384
  default=1.0,
374
385
  help="Scaling factor to map candidate translations onto the target. "
375
386
  "Assuming coordinates are in Å and target sampling rate are 3Å/voxel, "
376
- "the corresponding orientations_scaling would be 3.",
387
+ "the corresponding --orientations-scaling would be 3.",
377
388
  )
378
389
  sampling_group.add_argument(
379
- "--orientations_cone",
390
+ "--orientations-cone",
380
391
  required=False,
381
392
  type=float,
382
393
  default=20.0,
383
394
  help="Accept orientations within specified cone angle of each orientation.",
384
395
  )
385
396
  sampling_group.add_argument(
386
- "--orientations_uncertainty",
397
+ "--orientations-uncertainty",
387
398
  required=False,
388
399
  type=str,
389
400
  default="10",
@@ -394,7 +405,7 @@ def parse_args():
394
405
  scoring_group = parser.add_argument_group("Scoring")
395
406
  scoring_group.add_argument(
396
407
  "-s",
397
- dest="score",
408
+ "--score",
398
409
  type=str,
399
410
  default="FLCSphericalMask",
400
411
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
@@ -406,74 +417,64 @@ def parse_args():
406
417
 
407
418
  angular_exclusive.add_argument(
408
419
  "-a",
409
- dest="angular_sampling",
420
+ "--angular-sampling",
410
421
  type=check_positive,
411
422
  default=None,
412
- help="Angular sampling rate using optimized rotational sets."
413
- "A lower number yields more rotations. Values >= 180 sample only the identity.",
423
+ help="Angular sampling rate. Lower values = more rotations, higher precision.",
414
424
  )
415
425
  angular_exclusive.add_argument(
416
- "--cone_angle",
417
- dest="cone_angle",
426
+ "--cone-angle",
418
427
  type=check_positive,
419
428
  default=None,
420
429
  help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
421
430
  "narrow interval around a known orientation, e.g. for surface oversampling.",
422
431
  )
423
432
  angular_exclusive.add_argument(
424
- "--particle_diameter",
425
- dest="particle_diameter",
433
+ "--particle-diameter",
426
434
  type=check_positive,
427
435
  default=None,
428
436
  help="Particle diameter in units of sampling rate.",
429
437
  )
430
438
  angular_group.add_argument(
431
- "--cone_axis",
432
- dest="cone_axis",
439
+ "--cone-axis",
433
440
  type=check_positive,
434
441
  default=2,
435
442
  help="Principal axis to build cone around.",
436
443
  )
437
444
  angular_group.add_argument(
438
- "--invert_cone",
439
- dest="invert_cone",
445
+ "--invert-cone",
440
446
  action="store_true",
441
447
  help="Invert cone handedness.",
442
448
  )
443
449
  angular_group.add_argument(
444
- "--cone_sampling",
445
- dest="cone_sampling",
450
+ "--cone-sampling",
446
451
  type=check_positive,
447
452
  default=None,
448
453
  help="Sampling rate of the cone in degrees.",
449
454
  )
450
455
  angular_group.add_argument(
451
- "--axis_angle",
452
- dest="axis_angle",
456
+ "--axis-angle",
453
457
  type=check_positive,
454
458
  default=360.0,
455
459
  required=False,
456
460
  help="Sampling angle along the z-axis of the cone.",
457
461
  )
458
462
  angular_group.add_argument(
459
- "--axis_sampling",
460
- dest="axis_sampling",
463
+ "--axis-sampling",
461
464
  type=check_positive,
462
465
  default=None,
463
466
  required=False,
464
- help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
467
+ help="Sampling rate along the z-axis of the cone. Defaults to --cone-sampling.",
465
468
  )
466
469
  angular_group.add_argument(
467
- "--axis_symmetry",
468
- dest="axis_symmetry",
470
+ "--axis-symmetry",
469
471
  type=check_positive,
470
472
  default=1,
471
473
  required=False,
472
474
  help="N-fold symmetry around z-axis of the cone.",
473
475
  )
474
476
  angular_group.add_argument(
475
- "--no_use_optimized_set",
476
- dest="no_use_optimized_set",
477
+ "--no-use-optimized-set",
477
478
  action="store_true",
478
479
  default=False,
479
480
  required=False,
@@ -490,57 +491,40 @@ def parse_args():
490
491
  help="Number of cores used for template matching.",
491
492
  )
492
493
  computation_group.add_argument(
493
- "--use_gpu",
494
- dest="use_gpu",
495
- action="store_true",
496
- default=False,
497
- help="Whether to perform computations on the GPU.",
498
- )
499
- computation_group.add_argument(
500
- "--gpu_indices",
501
- dest="gpu_indices",
494
+ "--gpu-indices",
502
495
  type=str,
503
496
  default=None,
504
- help="Comma-separated list of GPU indices to use. For example,"
505
- " 0,1 for the first and second GPU. Only used if --use_gpu is set."
506
- " If not provided but --use_gpu is set, CUDA_VISIBLE_DEVICES will"
507
- " be respected.",
497
+ help="Comma-separated GPU indices (e.g., '0,1,2' for first 3 GPUs). Otherwise "
498
+ "CUDA_VISIBLE_DEVICES will be used.",
508
499
  )
509
500
  computation_group.add_argument(
510
- "-r",
511
- "--ram",
512
501
  "--memory",
513
- dest="memory",
514
502
  required=False,
515
503
  type=int,
516
504
  default=None,
517
505
  help="Amount of memory that can be used in bytes.",
518
506
  )
519
507
  computation_group.add_argument(
520
- "--memory_scaling",
521
- dest="memory_scaling",
508
+ "--memory-scaling",
522
509
  required=False,
523
510
  type=float,
524
511
  default=0.85,
525
512
  help="Fraction of available memory to be used. Ignored if --memory is set.",
526
513
  )
527
514
  computation_group.add_argument(
528
- "--temp_directory",
529
- dest="temp_directory",
530
- default=None,
531
- help="Directory for temporary objects. Faster I/O improves runtime.",
515
+ "--temp-directory",
516
+ default=gettempdir(),
517
+ help="Temporary directory for memmaps. Better I/O improves runtime.",
532
518
  )
533
519
  computation_group.add_argument(
534
520
  "--backend",
535
- dest="backend",
536
521
  default=be._backend_name,
537
522
  choices=be.available_backends(),
538
- help="[Expert] Overwrite default computation backend.",
523
+ help="Set computation backend.",
539
524
  )
540
525
  filter_group = parser.add_argument_group("Filters")
541
526
  filter_group.add_argument(
542
527
  "--lowpass",
543
- dest="lowpass",
544
528
  type=float,
545
529
  required=False,
546
530
  help="Resolution to lowpass filter template and target to in the same unit "
@@ -548,22 +532,19 @@ def parse_args():
548
532
  )
549
533
  filter_group.add_argument(
550
534
  "--highpass",
551
- dest="highpass",
552
535
  type=float,
553
536
  required=False,
554
537
  help="Resolution to highpass filter template and target to in the same unit "
555
538
  "as the sampling rate of template and target (typically Ångstrom).",
556
539
  )
557
540
  filter_group.add_argument(
558
- "--no_pass_smooth",
559
- dest="no_pass_smooth",
541
+ "--no-pass-smooth",
560
542
  action="store_false",
561
543
  default=True,
562
544
  help="Whether a hard edge filter should be used for --lowpass and --highpass.",
563
545
  )
564
546
  filter_group.add_argument(
565
- "--pass_format",
566
- dest="pass_format",
547
+ "--pass-format",
567
548
  type=str,
568
549
  required=False,
569
550
  default="sampling_rate",
@@ -572,15 +553,13 @@ def parse_args():
572
553
  "Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
573
554
  )
574
555
  filter_group.add_argument(
575
- "--whiten_spectrum",
576
- dest="whiten_spectrum",
556
+ "--whiten-spectrum",
577
557
  action="store_true",
578
558
  default=None,
579
559
  help="Apply spectral whitening to template and target based on target spectrum.",
580
560
  )
581
561
  filter_group.add_argument(
582
- "--wedge_axes",
583
- dest="wedge_axes",
562
+ "--wedge-axes",
584
563
  type=str,
585
564
  required=False,
586
565
  default="2,0",
@@ -588,21 +567,19 @@ def parse_args():
588
567
  "for the typical projection over z and tilting over the x-axis.",
589
568
  )
590
569
  filter_group.add_argument(
591
- "--tilt_angles",
592
- dest="tilt_angles",
570
+ "--tilt-angles",
593
571
  type=str,
594
572
  required=False,
595
573
  default=None,
596
574
  help="Path to a file specifying tilt angles. This can be a Warp/M XML file, "
597
- "a tomostar STAR file, a tab-separated file with column name 'angles', or a "
598
- "single column file without header. Exposure will be taken from the input file "
599
- ", if you are using a tab-separated file, the column names 'angles' and "
600
- "'weights' need to be present. It is also possible to specify a continuous "
601
- "wedge mask using e.g., -50,45.",
575
+ "a tomostar STAR file, an MMOD file, a tab-separated file with column name "
576
+ "'angles', or a single column file without header. Exposure will be taken from "
577
+ "the input file , if you are using a tab-separated file, the column names "
578
+ "'angles' and 'weights' need to be present. It is also possible to specify a "
579
+ "continuous wedge mask using e.g., 50,45.",
602
580
  )
603
581
  filter_group.add_argument(
604
- "--tilt_weighting",
605
- dest="tilt_weighting",
582
+ "--tilt-weighting",
606
583
  type=str,
607
584
  required=False,
608
585
  choices=["angle", "relion", "grigorieff"],
@@ -611,28 +588,25 @@ def parse_args():
611
588
  "angle (cosine based weighting), "
612
589
  "relion (relion formalism for wedge weighting) requires,"
613
590
  "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
614
- "relion and grigorieff require electron doses in --tilt_angles weights column.",
591
+ "relion and grigorieff require electron doses in --tilt-angles weights column.",
615
592
  )
616
593
  filter_group.add_argument(
617
- "--reconstruction_filter",
618
- dest="reconstruction_filter",
594
+ "--reconstruction-filter",
619
595
  type=str,
620
596
  required=False,
621
597
  choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
622
- default=None,
598
+ default="ramp",
623
599
  help="Filter applied when reconstructing (N+1)-D from N-D filters.",
624
600
  )
625
601
  filter_group.add_argument(
626
- "--reconstruction_interpolation_order",
627
- dest="reconstruction_interpolation_order",
602
+ "--reconstruction-interpolation-order",
628
603
  type=int,
629
604
  default=1,
630
605
  required=False,
631
- help="Analogous to --interpolation_order but for reconstruction.",
606
+ help="Analogous to --interpolation-order but for reconstruction.",
632
607
  )
633
608
  filter_group.add_argument(
634
- "--no_filter_target",
635
- dest="no_filter_target",
609
+ "--no-filter-target",
636
610
  action="store_true",
637
611
  default=False,
638
612
  help="Whether to not apply potential filters to the target.",
@@ -640,66 +614,58 @@ def parse_args():
640
614
 
641
615
  ctf_group = parser.add_argument_group("Contrast Transfer Function")
642
616
  ctf_group.add_argument(
643
- "--ctf_file",
644
- dest="ctf_file",
617
+ "--ctf-file",
645
618
  type=str,
646
619
  required=False,
647
620
  default=None,
648
621
  help="Path to a file with CTF parameters. This can be a Warp/M XML file "
649
- "a GCTF/Relion STAR file, or the output of CTFFIND4. If the file does not "
650
- "specify tilt angles, the angles specified with --tilt_angles are used.",
622
+ "a GCTF/Relion STAR file, an MDOC file, or the output of CTFFIND4. If the file "
623
+ " does not specify tilt angles, --tilt-angles are used.",
651
624
  )
652
625
  ctf_group.add_argument(
653
626
  "--defocus",
654
- dest="defocus",
655
627
  type=float,
656
628
  required=False,
657
629
  default=None,
658
630
  help="Defocus in units of sampling rate (typically Ångstrom), e.g., 30000 "
659
- "for a defocus of 3 micrometer. Superseded by --ctf_file.",
631
+ "for a defocus of 3 micrometer. Superseded by --ctf-file.",
660
632
  )
661
633
  ctf_group.add_argument(
662
- "--phase_shift",
663
- dest="phase_shift",
634
+ "--phase-shift",
664
635
  type=float,
665
636
  required=False,
666
637
  default=0,
667
- help="Phase shift in degrees. Superseded by --ctf_file.",
638
+ help="Phase shift in degrees. Superseded by --ctf-file.",
668
639
  )
669
640
  ctf_group.add_argument(
670
- "--acceleration_voltage",
671
- dest="acceleration_voltage",
641
+ "--acceleration-voltage",
672
642
  type=float,
673
643
  required=False,
674
644
  default=300,
675
645
  help="Acceleration voltage in kV.",
676
646
  )
677
647
  ctf_group.add_argument(
678
- "--spherical_aberration",
679
- dest="spherical_aberration",
648
+ "--spherical-aberration",
680
649
  type=float,
681
650
  required=False,
682
651
  default=2.7e7,
683
652
  help="Spherical aberration in units of sampling rate (typically Ångstrom).",
684
653
  )
685
654
  ctf_group.add_argument(
686
- "--amplitude_contrast",
687
- dest="amplitude_contrast",
655
+ "--amplitude-contrast",
688
656
  type=float,
689
657
  required=False,
690
658
  default=0.07,
691
659
  help="Amplitude contrast.",
692
660
  )
693
661
  ctf_group.add_argument(
694
- "--no_flip_phase",
695
- dest="no_flip_phase",
662
+ "--no-flip-phase",
696
663
  action="store_false",
697
664
  required=False,
698
665
  help="Do not perform phase-flipping CTF correction.",
699
666
  )
700
667
  ctf_group.add_argument(
701
- "--correct_defocus_gradient",
702
- dest="correct_defocus_gradient",
668
+ "--correct-defocus-gradient",
703
669
  action="store_true",
704
670
  required=False,
705
671
  help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
@@ -708,56 +674,35 @@ def parse_args():
708
674
 
709
675
  performance_group = parser.add_argument_group("Performance")
710
676
  performance_group.add_argument(
711
- "--no_centering",
712
- dest="no_centering",
713
- action="store_true",
714
- help="Assumes the template is already centered and omits centering.",
715
- )
716
- performance_group.add_argument(
717
- "--pad_edges",
718
- dest="pad_edges",
677
+ "--centering",
719
678
  action="store_true",
720
- default=False,
721
- help="Whether to pad the edges of the target. Useful if the target does not "
722
- "a well-defined bounding box. Defaults to True if splitting is required.",
679
+ help="Center the template in the box if it has not been done already.",
723
680
  )
724
681
  performance_group.add_argument(
725
- "--pad_filter",
726
- dest="pad_filter",
682
+ "--pad-edges",
727
683
  action="store_true",
728
684
  default=False,
729
- help="Pads the filter to the shape of the target. Particularly useful for fast "
730
- "oscilating filters to avoid aliasing effects.",
685
+ help="Useful if the target does not have a well-defined bounding box. Will be "
686
+ "activated automatically if splitting is required to avoid boundary artifacts.",
731
687
  )
732
688
  performance_group.add_argument(
733
- "--interpolation_order",
734
- dest="interpolation_order",
689
+ "--interpolation-order",
735
690
  required=False,
736
691
  type=int,
737
- default=3,
738
- help="Spline interpolation used for rotations.",
739
- )
740
- performance_group.add_argument(
741
- "--use_mixed_precision",
742
- dest="use_mixed_precision",
743
- action="store_true",
744
- default=False,
745
- help="Use float16 for real values operations where possible. Not supported "
746
- "for jax backend.",
692
+ default=None,
693
+ help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
694
+ "and pytorch backends.",
747
695
  )
748
696
  performance_group.add_argument(
749
- "--use_memmap",
750
- dest="use_memmap",
697
+ "--use-memmap",
751
698
  action="store_true",
752
699
  default=False,
753
- help="Use memmaps to offload large data objects to disk. "
754
- "Particularly useful for large inputs in combination with --use_gpu.",
700
+ help="Memmap large data to disk, e.g., matching on unbinned tomograms.",
755
701
  )
756
702
 
757
- analyzer_group = parser.add_argument_group("Analyzer")
703
+ analyzer_group = parser.add_argument_group("Output / Analysis")
758
704
  analyzer_group.add_argument(
759
- "--score_threshold",
760
- dest="score_threshold",
705
+ "--score-threshold",
761
706
  required=False,
762
707
  type=float,
763
708
  default=0,
@@ -765,21 +710,26 @@ def parse_args():
765
710
  )
766
711
  analyzer_group.add_argument(
767
712
  "-p",
768
- dest="peak_calling",
713
+ "--peak-calling",
769
714
  action="store_true",
770
715
  default=False,
771
716
  help="Perform peak calling instead of score aggregation.",
772
717
  )
773
718
  analyzer_group.add_argument(
774
- "--num_peaks",
775
- dest="num_peaks",
776
- action="store_true",
719
+ "--num-peaks",
720
+ type=int,
777
721
  default=1000,
778
722
  help="Number of peaks to call, 1000 by default.",
779
723
  )
780
724
  args = parser.parse_args()
781
725
  args.version = __version__
782
726
 
727
+ if args.interpolation_order is None:
728
+ args.interpolation_order = 3
729
+ if args.backend in ("jax", "pytorch"):
730
+ args.interpolation_order = 1
731
+ args.reconstruction_interpolation_order = 1
732
+
783
733
  if args.interpolation_order < 0:
784
734
  args.interpolation_order = None
785
735
 
@@ -787,35 +737,28 @@ def parse_args():
787
737
  args.temp_directory = gettempdir()
788
738
 
789
739
  os.environ["TMPDIR"] = args.temp_directory
790
- if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
791
- raise ValueError(
792
- f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
793
- )
794
-
795
740
  if args.gpu_indices is not None:
796
741
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
797
742
 
798
- if args.use_gpu:
799
- warnings.warn(
800
- "The use_gpu flag is no longer required and automatically "
801
- "determined based on the selected backend."
802
- )
803
-
804
- if args.tilt_angles is not None:
805
- if args.wedge_axes is None:
806
- raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
807
- if not exists(args.tilt_angles):
808
- try:
809
- float(args.tilt_angles.split(",")[0])
810
- except ValueError:
811
- raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
743
+ if args.tilt_angles is not None and not exists(args.tilt_angles):
744
+ try:
745
+ float(args.tilt_angles.split(",")[0])
746
+ except Exception:
747
+ raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
812
748
 
813
749
  if args.ctf_file is not None and args.tilt_angles is None:
814
- raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
815
-
816
- if args.wedge_axes is not None:
817
- args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
750
+ # Check if tilt angles can be extracted from CTF specification
751
+ try:
752
+ ctf = CTF.from_file(args.ctf_file)
753
+ if ctf.angles is None:
754
+ raise ValueError
755
+ args.tilt_angles = args.ctf_file
756
+ except Exception:
757
+ raise ValueError(
758
+ "Need to specify --tilt-angles when not provided in --ctf-file."
759
+ )
818
760
 
761
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
819
762
  if args.orientations is not None:
820
763
  orientations = Orientations.from_file(args.orientations)
821
764
  orientations.translations = np.divide(
@@ -823,6 +766,14 @@ def parse_args():
823
766
  )
824
767
  args.orientations = orientations
825
768
 
769
+ args.target = abspath(args.target)
770
+ if args.target_mask is not None:
771
+ args.target_mask = abspath(args.target_mask)
772
+
773
+ args.template = abspath(args.template)
774
+ if args.template_mask is not None:
775
+ args.template_mask = abspath(args.template_mask)
776
+
826
777
  return args
827
778
 
828
779
 
@@ -840,15 +791,23 @@ def main():
840
791
  sampling_rate=target.sampling_rate,
841
792
  )
842
793
 
794
+ if np.allclose(target.sampling_rate, 1):
795
+ warnings.warn(
796
+ "Target sampling rate is 1.0, which may indicate missing or incorrect "
797
+ "metadata. Verify that your target file contains proper sampling rate "
798
+ "information, as filters (CTF, BandPass) require accurate sampling rates "
799
+ "to function correctly."
800
+ )
801
+
843
802
  if target.sampling_rate.size == template.sampling_rate.size:
844
803
  if not np.allclose(
845
804
  np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
846
805
  ):
847
- print(
848
- f"Resampling template to {target.sampling_rate}. "
849
- "Consider providing a template with the same sampling rate as the target."
806
+ warnings.warn(
807
+ f"Sampling rate mismatch detected: target={target.sampling_rate} "
808
+ f"template={template.sampling_rate}. Proceeding with user-provided "
809
+ f"values. Make sure this is intentional. "
850
810
  )
851
- template = template.resample(target.sampling_rate, order=3)
852
811
 
853
812
  template_mask = load_and_validate_mask(
854
813
  mask_target=template, mask_path=args.template_mask
@@ -879,8 +838,9 @@ def main():
879
838
 
880
839
  initial_shape = template.shape
881
840
  translation = np.zeros(len(template.shape), dtype=np.float32)
882
- if not args.no_centering:
841
+ if args.centering:
883
842
  template, translation = template.centered(0)
843
+
884
844
  print_block(
885
845
  name="Template",
886
846
  data={
@@ -892,7 +852,7 @@ def main():
892
852
 
893
853
  if template_mask is None:
894
854
  template_mask = template.empty
895
- if not args.no_centering:
855
+ if not args.centering:
896
856
  enclosing_box = template.minimum_enclosing_box(
897
857
  0, use_geometric_center=False
898
858
  )
@@ -924,16 +884,13 @@ def main():
924
884
  print("\n" + "-" * 80)
925
885
 
926
886
  if args.scramble_phases:
927
- template.data = scramble_phases(
928
- template.data, noise_proportion=1.0, normalize_power=False
929
- )
887
+ template.data = scramble_phases(template.data, noise_proportion=1.0)
930
888
 
931
889
  callback_class = MaxScoreOverRotations
932
- if args.peak_calling:
933
- callback_class = PeakCallerMaximumFilter
934
-
935
890
  if args.orientations is not None:
936
891
  callback_class = MaxScoreOverRotationsConstrained
892
+ elif args.peak_calling:
893
+ callback_class = PeakCallerMaximumFilter
937
894
 
938
895
  # Determine suitable backend for the selected operation
939
896
  available_backends = be.available_backends()
@@ -958,15 +915,19 @@ def main():
958
915
  "Assuming device 0.",
959
916
  )
960
917
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
961
- else:
962
- args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
918
+
919
+ args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
963
920
  args.gpu_indices = [
964
921
  int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
965
922
  ]
966
923
 
967
924
  # Finally set the desired backend
968
925
  device = "cuda"
926
+ args.use_gpu = False
969
927
  be.change_backend(args.backend)
928
+ if args.backend in ("jax", "pytorch", "cupy"):
929
+ args.use_gpu = True
930
+
970
931
  if args.backend == "pytorch":
971
932
  try:
972
933
  be.change_backend("pytorch", device=device)
@@ -975,15 +936,8 @@ def main():
975
936
  except Exception as e:
976
937
  print(e)
977
938
  device = "cpu"
939
+ args.use_gpu = False
978
940
  be.change_backend("pytorch", device=device)
979
- if args.use_mixed_precision:
980
- be.change_backend(
981
- backend_name=args.backend,
982
- default_dtype=be._array_backend.float16,
983
- complex_dtype=be._array_backend.complex64,
984
- default_dtype_int=be._array_backend.int16,
985
- device=device,
986
- )
987
941
 
988
942
  available_memory = be.get_available_memory() * be.device_count()
989
943
  if args.memory is None:
@@ -1012,6 +966,8 @@ def main():
1012
966
  target_dim=target.metadata.get("batch_dimension", None),
1013
967
  template_dim=template.metadata.get("batch_dimension", None),
1014
968
  )
969
+ args.batch_dims = tuple(int(x) for x in np.where(matching_data._batch_mask)[0])
970
+
1015
971
  splits, schedule = compute_schedule(args, matching_data, callback_class)
1016
972
 
1017
973
  n_splits = np.prod(list(splits.values()))
@@ -1022,7 +978,7 @@ def main():
1022
978
  options = {
1023
979
  "Angular Sampling": f"{args.angular_sampling}"
1024
980
  f" [{matching_data.rotations.shape[0]} rotations]",
1025
- "Center Template": not args.no_centering,
981
+ "Center Template": args.centering,
1026
982
  "Scramble Template": args.scramble_phases,
1027
983
  "Invert Contrast": args.invert_target_contrast,
1028
984
  "Extend Target Edges": args.pad_edges,
@@ -1040,7 +996,6 @@ def main():
1040
996
  compute_options = {
1041
997
  "Backend": be._BACKEND_REGISTRY[be._backend_name],
1042
998
  "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
1043
- "Use Mixed Precision": args.use_mixed_precision,
1044
999
  "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1045
1000
  "Temporary Directory": args.temp_directory,
1046
1001
  "Target Splits": f"{target_split} [N={n_splits}]",
@@ -1061,16 +1016,10 @@ def main():
1061
1016
  "Tilt Angles": args.tilt_angles,
1062
1017
  "Tilt Weighting": args.tilt_weighting,
1063
1018
  "Reconstruction Filter": args.reconstruction_filter,
1064
- "Extend Filter Grid": args.pad_filter,
1065
1019
  }
1066
1020
  if args.ctf_file is not None or args.defocus is not None:
1067
1021
  filter_args["CTF File"] = args.ctf_file
1068
- filter_args["Defocus"] = args.defocus
1069
- filter_args["Phase Shift"] = args.phase_shift
1070
1022
  filter_args["Flip Phase"] = args.no_flip_phase
1071
- filter_args["Acceleration Voltage"] = args.acceleration_voltage
1072
- filter_args["Spherical Aberration"] = args.spherical_aberration
1073
- filter_args["Amplitude Contrast"] = args.amplitude_contrast
1074
1023
  filter_args["Correct Defocus"] = args.correct_defocus_gradient
1075
1024
 
1076
1025
  filter_args = {k: v for k, v in filter_args.items() if v is not None}
@@ -1098,7 +1047,10 @@ def main():
1098
1047
 
1099
1048
  print_block(
1100
1049
  name="Analyzer",
1101
- data={"Analyzer": callback_class, **analyzer_args},
1050
+ data={
1051
+ "Analyzer": callback_class,
1052
+ **{sanitize_name(k): v for k, v in analyzer_args.items()},
1053
+ },
1102
1054
  label_width=max(len(key) for key in options.keys()) + 3,
1103
1055
  )
1104
1056
  print("\n" + "-" * 80)
@@ -1119,7 +1071,6 @@ def main():
1119
1071
  callback_class_args=analyzer_args,
1120
1072
  target_splits=splits,
1121
1073
  pad_target_edges=args.pad_edges,
1122
- pad_template_filter=args.pad_filter,
1123
1074
  interpolation_order=args.interpolation_order,
1124
1075
  )
1125
1076