pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
  2. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +473 -140
  8. scripts/match_template_filters.py +458 -169
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +278 -148
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +22 -12
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +85 -64
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +86 -60
  24. tme/matching_exhaustive.py +245 -166
  25. tme/matching_optimization.py +137 -69
  26. tme/matching_utils.py +1 -1
  27. tme/orientations.py +175 -55
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +188 -0
  30. tme/preprocessing/composable_filter.py +31 -0
  31. tme/preprocessing/compose.py +51 -0
  32. tme/preprocessing/frequency_filters.py +378 -0
  33. tme/preprocessing/tilt_series.py +1017 -0
  34. tme/preprocessor.py +17 -7
  35. tme/structure.py +4 -1
  36. pytme-0.2.0b0.dist-info/RECORD +0 -66
  37. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  38. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  39. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  40. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  41. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -11,14 +11,16 @@ import warnings
  import importlib.util
  from sys import exit
  from time import time
+ from typing import Tuple
  from copy import deepcopy
- from os.path import abspath
+ from os.path import abspath, exists

  import numpy as np

- from tme import Density, Preprocessor, __version__
+ from tme import Density, __version__
  from tme.matching_utils import (
  get_rotation_matrices,
+ get_rotations_around_vector,
  compute_parallelization_schedule,
  euler_from_rotationmatrix,
  scramble_phases,
@@ -32,6 +34,7 @@ from tme.analyzer import (
  PeakCallerMaximumFilter,
  )
  from tme.backends import backend
+ from tme.preprocessing import Compose


  def get_func_fullname(func) -> str:
@@ -150,6 +153,187 @@ def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
  return True


+ def parse_rotation_logic(args, ndim):
+ if args.angular_sampling is not None:
+ rotations = get_rotation_matrices(
+ angular_sampling=args.angular_sampling,
+ dim=ndim,
+ use_optimized_set=not args.no_use_optimized_set,
+ )
+ if args.angular_sampling >= 180:
+ rotations = np.eye(ndim).reshape(1, ndim, ndim)
+ return rotations
+
+ if args.axis_sampling is None:
+ args.axis_sampling = args.cone_sampling
+
+ rotations = get_rotations_around_vector(
+ cone_angle=args.cone_angle,
+ cone_sampling=args.cone_sampling,
+ axis_angle=args.axis_angle,
+ axis_sampling=args.axis_sampling,
+ n_symmetry=args.axis_symmetry,
+ )
+ return rotations
+
+
+ # TODO: Think about whether wedge mask should also be added to target
+ # For now leave it at the cost of incorrect upper bound on the scores
+ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
+ from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
+ from tme.preprocessing.tilt_series import (
+ Wedge,
+ WedgeReconstructed,
+ ReconstructFromTilt,
+ )
+
+ template_filter, target_filter = [], []
+ if args.tilt_angles is not None:
+ try:
+ wedge = Wedge.from_file(args.tilt_angles)
+ wedge.weight_type = args.tilt_weighting
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
+ wedge = WedgeReconstructed(
+ angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
+ )
+ except FileNotFoundError:
+ tilt_step, create_continuous_wedge = None, True
+ tilt_start, tilt_stop = args.tilt_angles.split(",")
+ if ":" in tilt_stop:
+ create_continuous_wedge = False
+ tilt_stop, tilt_step = tilt_stop.split(":")
+ tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
+ tilt_angles = (tilt_start, tilt_stop)
+ if tilt_step is not None:
+ tilt_step = float(tilt_step)
+ tilt_angles = np.arange(
+ -tilt_start, tilt_stop + tilt_step, tilt_step
+ ).tolist()
+
+ if args.tilt_weighting is not None and tilt_step is None:
+ raise ValueError(
+ "Tilt weighting is not supported for continuous wedges."
+ )
+ if args.tilt_weighting not in ("angle", None):
+ raise ValueError(
+ "Tilt weighting schemes other than 'angle' or 'None' require "
+ "a specification of electron doses via --tilt_angles."
+ )
+
+ wedge = Wedge(
+ angles=tilt_angles,
+ opening_axis=args.wedge_axes[0],
+ tilt_axis=args.wedge_axes[1],
+ shape=None,
+ weight_type=None,
+ weights=np.ones_like(tilt_angles),
+ )
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
+ wedge = WedgeReconstructed(
+ angles=tilt_angles,
+ weight_wedge=args.tilt_weighting == "angle",
+ create_continuous_wedge=create_continuous_wedge,
+ )
+
+ wedge.opening_axis = args.wedge_axes[0]
+ wedge.tilt_axis = args.wedge_axes[1]
+ wedge.sampling_rate = template.sampling_rate
+ template_filter.append(wedge)
+ if not isinstance(wedge, WedgeReconstructed):
+ template_filter.append(
+ ReconstructFromTilt(
+ reconstruction_filter=args.reconstruction_filter,
+ interpolation_order=args.reconstruction_interpolation_order,
+ )
+ )
+
+ if args.ctf_file is not None or args.defocus is not None:
+ from tme.preprocessing.tilt_series import CTF
+
+ needs_reconstruction = True
+ if args.ctf_file is not None:
+ ctf = CTF.from_file(args.ctf_file)
+ n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
+ if n_tilts_ctfs != n_tils_angles:
+ raise ValueError(
+ f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
+ f"received {n_tils_angles} tilt angles. Expected one angle "
+ "per micrograph."
+ )
+ ctf.angles = wedge.angles
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
+ else:
+ needs_reconstruction = False
+ ctf = CTF(
+ defocus_x=[args.defocus],
+ phase_shift=[args.phase_shift],
+ defocus_y=None,
+ angles=[0],
+ shape=None,
+ return_real_fourier=True,
+ )
+ ctf.sampling_rate = template.sampling_rate
+ ctf.flip_phase = not args.no_flip_phase
+ ctf.amplitude_contrast = args.amplitude_contrast
+ ctf.spherical_aberration = args.spherical_aberration
+ ctf.acceleration_voltage = args.acceleration_voltage * 1e3
+ ctf.correct_defocus_gradient = args.correct_defocus_gradient
+
+ if not needs_reconstruction:
+ template_filter.append(ctf)
+ elif isinstance(template_filter[-1], ReconstructFromTilt):
+ template_filter.insert(-1, ctf)
+ else:
+ template_filter.insert(0, ctf)
+ template_filter.insert(
+ 1,
+ ReconstructFromTilt(
+ reconstruction_filter=args.reconstruction_filter,
+ interpolation_order=args.reconstruction_interpolation_order,
+ ),
+ )
+
+ if args.lowpass is not None or args.highpass is not None:
+ lowpass, highpass = args.lowpass, args.highpass
+ if args.pass_format == "voxel":
+ if lowpass is not None:
+ lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
+ if highpass is not None:
+ highpass = np.max(np.multiply(highpass, template.sampling_rate))
+ elif args.pass_format == "frequency":
+ if lowpass is not None:
+ lowpass = np.max(np.divide(template.sampling_rate, lowpass))
+ if highpass is not None:
+ highpass = np.max(np.divide(template.sampling_rate, highpass))
+
+ bandpass = BandPassFilter(
+ use_gaussian=args.no_pass_smooth,
+ lowpass=lowpass,
+ highpass=highpass,
+ sampling_rate=template.sampling_rate,
+ )
+ template_filter.append(bandpass)
+ target_filter.append(bandpass)
+
+ if args.whiten_spectrum:
+ whitening_filter = LinearWhiteningFilter()
+ template_filter.append(whitening_filter)
+ target_filter.append(whitening_filter)
+
+ needs_reconstruction = any(
+ [isinstance(t, ReconstructFromTilt) for t in template_filter]
+ )
+ if needs_reconstruction and args.reconstruction_filter is None:
+ warnings.warn(
+ "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
+ )
+
+ template_filter = Compose(template_filter) if len(template_filter) else None
+ target_filter = Compose(target_filter) if len(target_filter) else None
+
+ return template_filter, target_filter
+
+
  def parse_args():
  parser = argparse.ArgumentParser(description="Perform template matching.")

@@ -228,15 +412,65 @@ def parse_args():
  default=False,
  help="Perform peak calling instead of score aggregation.",
  )
- scoring_group.add_argument(
+
+ angular_group = parser.add_argument_group("Angular Sampling")
+ angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
+
+ angular_exclusive.add_argument(
  "-a",
  dest="angular_sampling",
  type=check_positive,
- default=40.0,
- help="Angular sampling rate for template matching. "
+ default=None,
+ help="Angular sampling rate using optimized rotational sets. "
  "A lower number yields more rotations. Values >= 180 sample only the identity.",
  )
-
+ angular_exclusive.add_argument(
+ "--cone_angle",
+ dest="cone_angle",
+ type=check_positive,
+ default=None,
+ help="Half-angle of the cone to be sampled in degrees. Allows sampling a "
+ "narrow interval around a known orientation, e.g. for surface oversampling.",
+ )
+ angular_group.add_argument(
+ "--cone_sampling",
+ dest="cone_sampling",
+ type=check_positive,
+ default=None,
+ help="Sampling rate of the cone in degrees.",
+ )
+ angular_group.add_argument(
+ "--axis_angle",
+ dest="axis_angle",
+ type=check_positive,
+ default=360.0,
+ required=False,
+ help="Sampling angle along the z-axis of the cone. Defaults to 360.",
+ )
+ angular_group.add_argument(
+ "--axis_sampling",
+ dest="axis_sampling",
+ type=check_positive,
+ default=None,
+ required=False,
+ help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
+ )
+ angular_group.add_argument(
+ "--axis_symmetry",
+ dest="axis_symmetry",
+ type=check_positive,
+ default=1,
+ required=False,
+ help="N-fold symmetry around z-axis of the cone.",
+ )
+ angular_group.add_argument(
+ "--no_use_optimized_set",
+ dest="no_use_optimized_set",
+ action="store_true",
+ default=False,
+ required=False,
+ help="Whether to use random uniform instead of optimized rotation sets.",
+ )

  computation_group = parser.add_argument_group("Computation")
  computation_group.add_argument(
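
The options above feed parse_rotation_logic, added earlier in this diff. A minimal sketch of the cone-restricted sampling they enable, assuming get_rotations_around_vector returns a stack of rotation matrices; the numeric values are illustrative only:

    import numpy as np
    from tme.matching_utils import get_rotations_around_vector

    # Rotations restricted to a 15 degree half-angle cone, sampled every 2.5 degrees,
    # with 4-fold symmetry around the cone axis (mirrors parse_rotation_logic above).
    rotations = get_rotations_around_vector(
        cone_angle=15.0,     # --cone_angle
        cone_sampling=2.5,   # --cone_sampling
        axis_angle=360.0,    # --axis_angle
        axis_sampling=2.5,   # --axis_sampling, falls back to --cone_sampling when unset
        n_symmetry=4,        # --axis_symmetry
    )
    print(rotations.shape)   # expected (n_rotations, 3, 3)
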
@@ -282,21 +516,6 @@ def parse_args():
  help="Fraction of available memory that can be used. Defaults to 0.85 and is "
  "ignored if --ram is set",
  )
- computation_group.add_argument(
- "--use_mixed_precision",
- dest="use_mixed_precision",
- action="store_true",
- default=False,
- help="Use float16 for real values operations where possible.",
- )
- computation_group.add_argument(
- "--use_memmap",
- dest="use_memmap",
- action="store_true",
- default=False,
- help="Use memmaps to offload large data objects to disk. "
- "Particularly useful for large inputs in combination with --use_gpu.",
- )
  computation_group.add_argument(
  "--temp_directory",
  dest="temp_directory",
@@ -306,61 +525,160 @@ def parse_args():

  filter_group = parser.add_argument_group("Filters")
  filter_group.add_argument(
- "--gaussian_sigma",
- dest="gaussian_sigma",
+ "--lowpass",
+ dest="lowpass",
+ type=float,
+ required=False,
+ help="Resolution to lowpass filter template and target to in the same unit "
+ "as the sampling rate of template and target (typically Ångstrom).",
+ )
+ filter_group.add_argument(
+ "--highpass",
+ dest="highpass",
  type=float,
  required=False,
- help="Sigma parameter for Gaussian filtering the template.",
+ help="Resolution to highpass filter template and target to in the same unit "
+ "as the sampling rate of template and target (typically Ångstrom).",
  )
  filter_group.add_argument(
- "--bandpass_band",
- dest="bandpass_band",
+ "--no_pass_smooth",
+ dest="no_pass_smooth",
+ action="store_false",
+ default=True,
+ help="Whether a hard edge filter should be used for --lowpass and --highpass.",
+ )
+ filter_group.add_argument(
+ "--pass_format",
+ dest="pass_format",
  type=str,
  required=False,
- help="Comma separated start and stop frequency for bandpass filtering the"
- " template, e.g. 0.1, 0.5",
+ choices=["sampling_rate", "voxel", "frequency"],
+ help="How values passed to --lowpass and --highpass should be interpreted. "
+ "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
  )
  filter_group.add_argument(
- "--bandpass_smooth",
- dest="bandpass_smooth",
- type=float,
+ "--whiten_spectrum",
+ dest="whiten_spectrum",
+ action="store_true",
+ default=None,
+ help="Apply spectral whitening to template and target based on target spectrum.",
+ )
+ filter_group.add_argument(
+ "--wedge_axes",
+ dest="wedge_axes",
+ type=str,
  required=False,
  default=None,
- help="Sigma smooth parameter for the bandpass filter.",
+ help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
+ "in z-direction and tilted over the x axis.",
  )
  filter_group.add_argument(
- "--tilt_range",
- dest="tilt_range",
+ "--tilt_angles",
+ dest="tilt_angles",
  type=str,
  required=False,
- help="Comma separated start and stop stage tilt angle, e.g. '50,45'. Used"
- " to create a wedge mask to be applied to the template.",
+ default=None,
+ help="Path to a tab-separated file containing the column angles and optionally "
+ "weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
+ "yields a continuous wedge mask. Alternatively, a tilt step size can be "
+ "specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
  )
  filter_group.add_argument(
- "--tilt_step",
- dest="tilt_step",
- type=float,
+ "--tilt_weighting",
+ dest="tilt_weighting",
+ type=str,
  required=False,
+ choices=["angle", "relion", "grigorieff"],
  default=None,
- help="Step size between tilts. e.g. '5'. When set the wedge mask"
- " reflects the individual tilts, otherwise a continuous mask is used.",
+ help="Weighting scheme used to reweight individual tilts. Available options: "
+ "angle (cosine based weighting), "
+ "relion (relion formalism for wedge weighting), "
+ "grigorieff (exposure filter as defined in Grant and Grigorieff 2015). "
+ "relion and grigorieff require electron doses in the --tilt_angles weights column.",
  )
  filter_group.add_argument(
- "--wedge_axes",
- dest="wedge_axes",
+ "--reconstruction_filter",
+ dest="reconstruction_filter",
  type=str,
  required=False,
- default="0,2",
- help="Axis index of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open in"
- " z and tilted over x.",
+ choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
+ default=None,
+ help="Filter applied when reconstructing (N+1)-D from N-D filters.",
  )
  filter_group.add_argument(
- "--wedge_smooth",
- dest="wedge_smooth",
+ "--reconstruction_interpolation_order",
+ dest="reconstruction_interpolation_order",
+ type=int,
+ default=1,
+ required=False,
+ help="Analogous to --interpolation_order but for reconstruction.",
+ )
+
+ ctf_group = parser.add_argument_group("Contrast Transfer Function")
+ ctf_group.add_argument(
+ "--ctf_file",
+ dest="ctf_file",
+ type=str,
+ required=False,
+ default=None,
+ help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
+ "interpreted as a tilt obtained at the angle specified in --tilt_angles.",
+ )
+ ctf_group.add_argument(
+ "--defocus",
+ dest="defocus",
  type=float,
  required=False,
  default=None,
- help="Sigma smooth parameter for the wedge mask.",
+ help="Defocus in units of sampling rate (typically Ångstrom). "
+ "Superseded by --ctf_file.",
+ )
+ ctf_group.add_argument(
+ "--phase_shift",
+ dest="phase_shift",
+ type=float,
+ required=False,
+ default=0,
+ help="Phase shift in degrees. Superseded by --ctf_file.",
+ )
+ ctf_group.add_argument(
+ "--acceleration_voltage",
+ dest="acceleration_voltage",
+ type=float,
+ required=False,
+ default=300,
+ help="Acceleration voltage in kV, defaults to 300.",
+ )
+ ctf_group.add_argument(
+ "--spherical_aberration",
+ dest="spherical_aberration",
+ type=float,
+ required=False,
+ default=2.7e7,
+ help="Spherical aberration in units of sampling rate (typically Ångstrom).",
+ )
+ ctf_group.add_argument(
+ "--amplitude_contrast",
+ dest="amplitude_contrast",
+ type=float,
+ required=False,
+ default=0.07,
+ help="Amplitude contrast, defaults to 0.07.",
+ )
+ ctf_group.add_argument(
+ "--no_flip_phase",
+ dest="no_flip_phase",
+ action="store_false",
+ required=False,
+ help="Do not flip the phase of the computed CTF.",
+ )
+ ctf_group.add_argument(
+ "--correct_defocus_gradient",
+ dest="correct_defocus_gradient",
+ action="store_true",
+ required=False,
+ help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
+ "defocus gradients.",
  )

  performance_group = parser.add_argument_group("Performance")
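
The --lowpass and --highpass values are converted in setup_filter according to --pass_format before the BandPassFilter is built. A small worked example of that conversion, using an illustrative sampling rate of 5 Å per voxel:

    import numpy as np

    sampling_rate = np.array([5.0, 5.0, 5.0])  # Å per voxel, illustrative

    # --pass_format sampling_rate (default): the value is used as-is, e.g. 30 -> 30 Å.
    # --pass_format voxel: multiply by the sampling rate, e.g. 2 voxels -> 10 Å.
    lowpass_voxel = np.max(np.multiply(2, sampling_rate))       # 10.0
    # --pass_format frequency: divide the sampling rate by the value, e.g. 0.25 -> 20 Å.
    lowpass_frequency = np.max(np.divide(sampling_rate, 0.25))  # 20.0
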
@@ -412,6 +730,21 @@ def parse_args():
  help="Spline interpolation used for template rotations. If less than zero "
  "no interpolation is performed.",
  )
+ performance_group.add_argument(
+ "--use_mixed_precision",
+ dest="use_mixed_precision",
+ action="store_true",
+ default=False,
+ help="Use float16 for real values operations where possible.",
+ )
+ performance_group.add_argument(
+ "--use_memmap",
+ dest="use_memmap",
+ action="store_true",
+ default=False,
+ help="Use memmaps to offload large data objects to disk. "
+ "Particularly useful for large inputs in combination with --use_gpu.",
+ )

  analyzer_group = parser.add_argument_group("Analyzer")
  analyzer_group.add_argument(
@@ -424,6 +757,7 @@ def parse_args():
  )

  args = parser.parse_args()
+ args.version = __version__

  if args.interpolation_order < 0:
  args.interpolation_order = None
@@ -460,6 +794,21 @@ def parse_args():
  int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
  ]

+ if args.tilt_angles is not None:
+ if args.wedge_axes is None:
+ raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
+ if not exists(args.tilt_angles):
+ try:
+ float(args.tilt_angles.split(",")[0])
+ except ValueError:
+ raise ValueError(f"{args.tilt_angles} is neither a file nor a range.")
+
+ if args.ctf_file is not None and args.tilt_angles is None:
+ raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
+
+ if args.wedge_axes is not None:
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
+
  return args


@@ -477,12 +826,13 @@ def main():
  sampling_rate=target.sampling_rate,
  )

- if not np.allclose(target.sampling_rate, template.sampling_rate):
- print(
- f"Resampling template to {target.sampling_rate}. "
- "Consider providing a template with the same sampling rate as the target."
- )
- template = template.resample(target.sampling_rate, order=3)
+ if target.sampling_rate.size == template.sampling_rate.size:
+ if not np.allclose(target.sampling_rate, template.sampling_rate):
+ print(
+ f"Resampling template to {target.sampling_rate}. "
+ "Consider providing a template with the same sampling rate as the target."
+ )
+ template = template.resample(target.sampling_rate, order=3)

  template_mask = load_and_validate_mask(
  mask_target=template, mask_path=args.template_mask
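
The added .size check guards the comparison, presumably because np.allclose cannot broadcast sampling rates of different dimensionality (for instance a 2D template against a 3D target). A small illustration:

    import numpy as np

    target_sr = np.array([13.4, 13.4, 13.4])  # 3D target, illustrative values
    template_sr = np.array([13.4, 13.4])      # e.g. a 2D template

    # np.allclose(target_sr, template_sr) would raise a broadcast ValueError,
    # hence the size check before deciding whether to resample.
    if target_sr.size == template_sr.size and not np.allclose(target_sr, template_sr):
        pass  # resample the template here
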
@@ -536,51 +886,6 @@ def main():
  },
  )

- template_filter = {}
- if args.gaussian_sigma is not None:
- template.data = Preprocessor().gaussian_filter(
- sigma=args.gaussian_sigma, template=template.data
- )
-
- if args.bandpass_band is not None:
- bandpass_start, bandpass_stop = [
- float(x) for x in args.bandpass_band.split(",")
- ]
- if args.bandpass_smooth is None:
- args.bandpass_smooth = 0
-
- template_filter["bandpass_mask"] = {
- "minimum_frequency": bandpass_start,
- "maximum_frequency": bandpass_stop,
- "gaussian_sigma": args.bandpass_smooth,
- }
-
- if args.tilt_range is not None:
- args.wedge_smooth if args.wedge_smooth is not None else 0
- tilt_start, tilt_stop = [float(x) for x in args.tilt_range.split(",")]
- opening_axis, tilt_axis = [int(x) for x in args.wedge_axes.split(",")]
-
- if args.tilt_step is not None:
- template_filter["step_wedge_mask"] = {
- "start_tilt": tilt_start,
- "stop_tilt": tilt_stop,
- "tilt_step": args.tilt_step,
- "sigma": args.wedge_smooth,
- "opening_axis": opening_axis,
- "tilt_axis": tilt_axis,
- "omit_negative_frequencies": True,
- }
- else:
- template_filter["continuous_wedge_mask"] = {
- "start_tilt": tilt_start,
- "stop_tilt": tilt_stop,
- "tilt_axis": tilt_axis,
- "opening_axis": opening_axis,
- "infinite_plane": True,
- "sigma": args.wedge_smooth,
- "omit_negative_frequencies": True,
- }
-
  if template_mask is None:
  template_mask = template.empty
  if not args.no_centering:
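
The dictionary-based template_filter removed above is superseded by the composable filters that setup_filter assembles. A rough sketch of the equivalent wiring under the new scheme, using only constructors that appear in this diff; the parameter values are illustrative and this is not a drop-in replacement:

    from tme.preprocessing import Compose, BandPassFilter, LinearWhiteningFilter

    template_filter = Compose([
        BandPassFilter(
            use_gaussian=True,          # smooth edge, i.e. --no_pass_smooth not set
            lowpass=30.0,               # resolution in the unit of the sampling rate
            highpass=None,
            sampling_rate=(13.4, 13.4, 13.4),
        ),
        LinearWhiteningFilter(),        # added when --whiten_spectrum is set
    ])
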
@@ -660,31 +965,46 @@ def main():
  if args.memory is None:
  args.memory = int(args.memory_scaling * available_memory)

- target_padding = np.zeros_like(template.shape)
- if args.pad_target_edges:
- target_padding = template.shape
+ callback_class = MaxScoreOverRotations
+ if args.peak_calling:
+ callback_class = PeakCallerMaximumFilter
+
+ matching_data = MatchingData(
+ target=target,
+ template=template.data,
+ target_mask=target_mask,
+ template_mask=template_mask,
+ invert_target=args.invert_target_contrast,
+ rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
+ )

- template_box = template.shape
+ template_filter, target_filter = setup_filter(args, template, target)
+ matching_data.template_filter = template_filter
+ matching_data.target_filter = target_filter
+
+ template_box = matching_data._output_template_shape
  if not args.pad_fourier:
  template_box = np.ones(len(template_box), dtype=int)

- callback_class = MaxScoreOverRotations
- if args.peak_calling:
- callback_class = PeakCallerMaximumFilter
+ target_padding = np.zeros(
+ (backend.size(matching_data._output_template_shape)), dtype=int
+ )
+ if args.pad_target_edges:
+ target_padding = matching_data._output_template_shape

  splits, schedule = compute_parallelization_schedule(
  shape1=target.shape,
- shape2=template_box,
- shape1_padding=target_padding,
+ shape2=tuple(int(x) for x in template_box),
+ shape1_padding=tuple(int(x) for x in target_padding),
  max_cores=args.cores,
  max_ram=args.memory,
  split_only_outer=args.use_gpu,
  matching_method=args.score,
  analyzer_method=callback_class.__name__,
  backend=backend._backend_name,
- float_nbytes=backend.datatype_bytes(backend._default_dtype),
+ float_nbytes=backend.datatype_bytes(backend._float_dtype),
  complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
- integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
+ integer_nbytes=backend.datatype_bytes(backend._int_dtype),
  )

  if splits is None:
@@ -694,29 +1014,7 @@ def main():
  )
  exit(-1)

- analyzer_args = {
- "score_threshold": args.score_threshold,
- "number_of_peaks": 1000,
- "convolution_mode": "valid",
- "use_memmap": args.use_memmap,
- }
-
  matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
- matching_data = MatchingData(target=target, template=template.data)
- matching_data.rotations = get_rotation_matrices(
- angular_sampling=args.angular_sampling, dim=target.data.ndim
- )
- if args.angular_sampling >= 180:
- ndim = target.data.ndim
- matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)
-
- matching_data.template_filter = template_filter
- matching_data._invert_target = args.invert_target_contrast
- if target_mask is not None:
- matching_data.target_mask = target_mask
- if template_mask is not None:
- matching_data.template_mask = template_mask.data
-
  n_splits = np.prod(list(splits.values()))
  target_split = ", ".join(
  [":".join([str(x) for x in axis]) for axis in splits.items()]
@@ -746,10 +1044,45 @@ def main():
  label_width=max(len(key) for key in options.keys()) + 2,
  )

- options = {"Analyzer": callback_class, **analyzer_args}
+ filter_args = {
+ "Lowpass": args.lowpass,
+ "Highpass": args.highpass,
+ "Smooth Pass": args.no_pass_smooth,
+ "Pass Format": args.pass_format,
+ "Spectral Whitening": args.whiten_spectrum,
+ "Wedge Axes": args.wedge_axes,
+ "Tilt Angles": args.tilt_angles,
+ "Tilt Weighting": args.tilt_weighting,
+ "Reconstruction Filter": args.reconstruction_filter,
+ }
+ if args.ctf_file is not None or args.defocus is not None:
+ filter_args["CTF File"] = args.ctf_file
+ filter_args["Defocus"] = args.defocus
+ filter_args["Phase Shift"] = args.phase_shift
+ filter_args["No Flip Phase"] = args.no_flip_phase
+ filter_args["Acceleration Voltage"] = args.acceleration_voltage
+ filter_args["Spherical Aberration"] = args.spherical_aberration
+ filter_args["Amplitude Contrast"] = args.amplitude_contrast
+ filter_args["Correct Defocus"] = args.correct_defocus_gradient
+
+ filter_args = {k: v for k, v in filter_args.items() if v is not None}
+ if len(filter_args):
+ print_block(
+ name="Filters",
+ data=filter_args,
+ label_width=max(len(key) for key in options.keys()) + 2,
+ )
+
+ analyzer_args = {
+ "score_threshold": args.score_threshold,
+ "number_of_peaks": 1000,
+ "convolution_mode": "valid",
+ "use_memmap": args.use_memmap,
+ }
+ analyzer_args = {"Analyzer": callback_class, **analyzer_args}
  print_block(
  name="Score Analysis Options",
- data=options,
+ data=analyzer_args,
  label_width=max(len(key) for key in options.keys()) + 2,
  )
  print("\n" + "-" * 80)
@@ -780,16 +1113,16 @@ def main():
  candidates[0] *= target_mask.data
  with warnings.catch_warnings():
  warnings.simplefilter("ignore", category=UserWarning)
+ nbytes = backend.datatype_bytes(backend._float_dtype)
+ dtype = np.float32 if nbytes == 4 else np.float16
+ rot_dim = matching_data.rotations.shape[1]
  candidates[3] = {
  x: euler_from_rotationmatrix(
- np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
- candidates[0].ndim, candidates[0].ndim
- )
+ np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
  )
  for i, x in candidates[3].items()
  }
-
- candidates.append((target.origin, template.origin, target.sampling_rate, args))
+ candidates.append((target.origin, template.origin, template.sampling_rate, args))
  write_pickle(data=candidates, filename=args.output)

  runtime = time() - start
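
For reference, the candidates[3] mapping above converts rotation matrices that were stored as raw byte strings (dictionary keys) back into Euler angles. A self-contained sketch of that byte round trip, assuming float32 matrices as implied by the dtype selection when backend._float_dtype is 4 bytes; euler_from_rotationmatrix from tme.matching_utils would then be applied to the recovered matrix:

    import numpy as np

    rotation = np.eye(3, dtype=np.float32)  # one rotation matrix
    key = rotation.tobytes()                # serialized form used as a dictionary key

    rot_dim = 3
    recovered = np.frombuffer(key, dtype=np.float32).reshape(rot_dim, rot_dim)
    assert np.array_equal(recovered, rotation)
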