pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
  2. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +473 -140
  8. scripts/match_template_filters.py +458 -169
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +278 -148
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +22 -12
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +85 -64
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +86 -60
  24. tme/matching_exhaustive.py +245 -166
  25. tme/matching_optimization.py +137 -69
  26. tme/matching_utils.py +1 -1
  27. tme/orientations.py +175 -55
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +188 -0
  30. tme/preprocessing/composable_filter.py +31 -0
  31. tme/preprocessing/compose.py +51 -0
  32. tme/preprocessing/frequency_filters.py +378 -0
  33. tme/preprocessing/tilt_series.py +1017 -0
  34. tme/preprocessor.py +17 -7
  35. tme/structure.py +4 -1
  36. pytme-0.2.0b0.dist-info/RECORD +0 -66
  37. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  38. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  39. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  40. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  41. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -13,13 +13,14 @@ from sys import exit
13
13
  from time import time
14
14
  from typing import Tuple
15
15
  from copy import deepcopy
16
- from os.path import abspath
16
+ from os.path import abspath, exists
17
17
 
18
18
  import numpy as np
19
19
 
20
20
  from tme import Density, __version__
21
21
  from tme.matching_utils import (
22
22
  get_rotation_matrices,
23
+ get_rotations_around_vector,
23
24
  compute_parallelization_schedule,
24
25
  euler_from_rotationmatrix,
25
26
  scramble_phases,
@@ -32,8 +33,8 @@ from tme.analyzer import (
32
33
  MaxScoreOverRotations,
33
34
  PeakCallerMaximumFilter,
34
35
  )
35
- from tme.preprocessing import Compose
36
36
  from tme.backends import backend
37
+ from tme.preprocessing import Compose
37
38
 
38
39
 
39
40
  def get_func_fullname(func) -> str:
@@ -152,6 +153,187 @@ def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
152
153
  return True
153
154
 
154
155
 
156
+ def parse_rotation_logic(args, ndim):
157
+ if args.angular_sampling is not None:
158
+ rotations = get_rotation_matrices(
159
+ angular_sampling=args.angular_sampling,
160
+ dim=ndim,
161
+ use_optimized_set=not args.no_use_optimized_set,
162
+ )
163
+ if args.angular_sampling >= 180:
164
+ rotations = np.eye(ndim).reshape(1, ndim, ndim)
165
+ return rotations
166
+
167
+ if args.axis_sampling is None:
168
+ args.axis_sampling = args.cone_sampling
169
+
170
+ rotations = get_rotations_around_vector(
171
+ cone_angle=args.cone_angle,
172
+ cone_sampling=args.cone_sampling,
173
+ axis_angle=args.axis_angle,
174
+ axis_sampling=args.axis_sampling,
175
+ n_symmetry=args.axis_symmetry,
176
+ )
177
+ return rotations
178
+
179
+
180
+ # TODO: Think about whether wedge mask should also be added to target
181
+ # For now leave it at the cost of incorrect upper bound on the scores
182
+ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
183
+ from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
184
+ from tme.preprocessing.tilt_series import (
185
+ Wedge,
186
+ WedgeReconstructed,
187
+ ReconstructFromTilt,
188
+ )
189
+
190
+ template_filter, target_filter = [], []
191
+ if args.tilt_angles is not None:
192
+ try:
193
+ wedge = Wedge.from_file(args.tilt_angles)
194
+ wedge.weight_type = args.tilt_weighting
195
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
196
+ wedge = WedgeReconstructed(
197
+ angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
198
+ )
199
+ except FileNotFoundError:
200
+ tilt_step, create_continuous_wedge = None, True
201
+ tilt_start, tilt_stop = args.tilt_angles.split(",")
202
+ if ":" in tilt_stop:
203
+ create_continuous_wedge = False
204
+ tilt_stop, tilt_step = tilt_stop.split(":")
205
+ tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
206
+ tilt_angles = (tilt_start, tilt_stop)
207
+ if tilt_step is not None:
208
+ tilt_step = float(tilt_step)
209
+ tilt_angles = np.arange(
210
+ -tilt_start, tilt_stop + tilt_step, tilt_step
211
+ ).tolist()
212
+
213
+ if args.tilt_weighting is not None and tilt_step is None:
214
+ raise ValueError(
215
+ "Tilt weighting is not supported for continuous wedges."
216
+ )
217
+ if args.tilt_weighting not in ("angle", None):
218
+ raise ValueError(
219
+ "Tilt weighting schemes other than 'angle' or 'None' require "
220
+ "a specification of electron doses via --tilt_angles."
221
+ )
222
+
223
+ wedge = Wedge(
224
+ angles=tilt_angles,
225
+ opening_axis=args.wedge_axes[0],
226
+ tilt_axis=args.wedge_axes[1],
227
+ shape=None,
228
+ weight_type=None,
229
+ weights=np.ones_like(tilt_angles),
230
+ )
231
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
232
+ wedge = WedgeReconstructed(
233
+ angles=tilt_angles,
234
+ weight_wedge=args.tilt_weighting == "angle",
235
+ create_continuous_wedge=create_continuous_wedge,
236
+ )
237
+
238
+ wedge.opening_axis = args.wedge_axes[0]
239
+ wedge.tilt_axis = args.wedge_axes[1]
240
+ wedge.sampling_rate = template.sampling_rate
241
+ template_filter.append(wedge)
242
+ if not isinstance(wedge, WedgeReconstructed):
243
+ template_filter.append(
244
+ ReconstructFromTilt(
245
+ reconstruction_filter=args.reconstruction_filter,
246
+ interpolation_order=args.reconstruction_interpolation_order,
247
+ )
248
+ )
249
+
250
+ if args.ctf_file is not None or args.defocus is not None:
251
+ from tme.preprocessing.tilt_series import CTF
252
+
253
+ needs_reconstruction = True
254
+ if args.ctf_file is not None:
255
+ ctf = CTF.from_file(args.ctf_file)
256
+ n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
257
+ if n_tilts_ctfs != n_tils_angles:
258
+ raise ValueError(
259
+ f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
260
+ f"recieved {n_tils_angles} tilt angles. Expected one angle "
261
+ "per micrograph."
262
+ )
263
+ ctf.angles = wedge.angles
264
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
265
+ else:
266
+ needs_reconstruction = False
267
+ ctf = CTF(
268
+ defocus_x=[args.defocus],
269
+ phase_shift=[args.phase_shift],
270
+ defocus_y=None,
271
+ angles=[0],
272
+ shape=None,
273
+ return_real_fourier=True,
274
+ )
275
+ ctf.sampling_rate = template.sampling_rate
276
+ ctf.flip_phase = not args.no_flip_phase
277
+ ctf.amplitude_contrast = args.amplitude_contrast
278
+ ctf.spherical_aberration = args.spherical_aberration
279
+ ctf.acceleration_voltage = args.acceleration_voltage * 1e3
280
+ ctf.correct_defocus_gradient = args.correct_defocus_gradient
281
+
282
+ if not needs_reconstruction:
283
+ template_filter.append(ctf)
284
+ elif isinstance(template_filter[-1], ReconstructFromTilt):
285
+ template_filter.insert(-1, ctf)
286
+ else:
287
+ template_filter.insert(0, ctf)
288
+ template_filter.insert(
289
+ 1,
290
+ ReconstructFromTilt(
291
+ reconstruction_filter=args.reconstruction_filter,
292
+ interpolation_order=args.reconstruction_interpolation_order,
293
+ ),
294
+ )
295
+
296
+ if args.lowpass or args.highpass is not None:
297
+ lowpass, highpass = args.lowpass, args.highpass
298
+ if args.pass_format == "voxel":
299
+ if lowpass is not None:
300
+ lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
301
+ if highpass is not None:
302
+ highpass = np.max(np.multiply(highpass, template.sampling_rate))
303
+ elif args.pass_format == "frequency":
304
+ if lowpass is not None:
305
+ lowpass = np.max(np.divide(template.sampling_rate, lowpass))
306
+ if highpass is not None:
307
+ highpass = np.max(np.divide(template.sampling_rate, highpass))
308
+
309
+ bandpass = BandPassFilter(
310
+ use_gaussian=args.no_pass_smooth,
311
+ lowpass=lowpass,
312
+ highpass=highpass,
313
+ sampling_rate=template.sampling_rate,
314
+ )
315
+ template_filter.append(bandpass)
316
+ target_filter.append(bandpass)
317
+
318
+ if args.whiten_spectrum:
319
+ whitening_filter = LinearWhiteningFilter()
320
+ template_filter.append(whitening_filter)
321
+ target_filter.append(whitening_filter)
322
+
323
+ needs_reconstruction = any(
324
+ [isinstance(t, ReconstructFromTilt) for t in template_filter]
325
+ )
326
+ if needs_reconstruction and args.reconstruction_filter is None:
327
+ warnings.warn(
328
+ "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
329
+ )
330
+
331
+ template_filter = Compose(template_filter) if len(template_filter) else None
332
+ target_filter = Compose(target_filter) if len(target_filter) else None
333
+
334
+ return template_filter, target_filter
335
+
336
+
155
337
  def parse_args():
156
338
  parser = argparse.ArgumentParser(description="Perform template matching.")
157
339
 
@@ -224,13 +406,71 @@ def parse_args():
224
406
  help="Template matching scoring function.",
225
407
  )
226
408
  scoring_group.add_argument(
409
+ "-p",
410
+ dest="peak_calling",
411
+ action="store_true",
412
+ default=False,
413
+ help="Perform peak calling instead of score aggregation.",
414
+ )
415
+
416
+ angular_group = parser.add_argument_group("Angular Sampling")
417
+ angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
418
+
419
+ angular_exclusive.add_argument(
227
420
  "-a",
228
421
  dest="angular_sampling",
229
422
  type=check_positive,
230
- default=40.0,
231
- help="Angular sampling rate for template matching. "
423
+ default=None,
424
+ help="Angular sampling rate using optimized rotational sets."
232
425
  "A lower number yields more rotations. Values >= 180 sample only the identity.",
233
426
  )
427
+ angular_exclusive.add_argument(
428
+ "--cone_angle",
429
+ dest="cone_angle",
430
+ type=check_positive,
431
+ default=None,
432
+ help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
433
+ "narrow interval around a known orientation, e.g. for surface oversampling.",
434
+ )
435
+ angular_group.add_argument(
436
+ "--cone_sampling",
437
+ dest="cone_sampling",
438
+ type=check_positive,
439
+ default=None,
440
+ help="Sampling rate of the cone in degrees.",
441
+ )
442
+ angular_group.add_argument(
443
+ "--axis_angle",
444
+ dest="axis_angle",
445
+ type=check_positive,
446
+ default=360.0,
447
+ required=False,
448
+ help="Sampling angle along the z-axis of the cone. Defaults to 360.",
449
+ )
450
+ angular_group.add_argument(
451
+ "--axis_sampling",
452
+ dest="axis_sampling",
453
+ type=check_positive,
454
+ default=None,
455
+ required=False,
456
+ help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
457
+ )
458
+ angular_group.add_argument(
459
+ "--axis_symmetry",
460
+ dest="axis_symmetry",
461
+ type=check_positive,
462
+ default=1,
463
+ required=False,
464
+ help="N-fold symmetry around z-axis of the cone.",
465
+ )
466
+ angular_group.add_argument(
467
+ "--no_use_optimized_set",
468
+ dest="no_use_optimized_set",
469
+ action="store_true",
470
+ default=False,
471
+ required=False,
472
+ help="Whether to use random uniform instead of optimized rotation sets.",
473
+ )
234
474
 
235
475
  computation_group = parser.add_argument_group("Computation")
236
476
  computation_group.add_argument(
@@ -276,21 +516,6 @@ def parse_args():
276
516
  help="Fraction of available memory that can be used. Defaults to 0.85 and is "
277
517
  "ignored if --ram is set",
278
518
  )
279
- computation_group.add_argument(
280
- "--use_mixed_precision",
281
- dest="use_mixed_precision",
282
- action="store_true",
283
- default=False,
284
- help="Use float16 for real values operations where possible.",
285
- )
286
- computation_group.add_argument(
287
- "--use_memmap",
288
- dest="use_memmap",
289
- action="store_true",
290
- default=False,
291
- help="Use memmaps to offload large data objects to disk. "
292
- "Particularly useful for large inputs in combination with --use_gpu.",
293
- )
294
519
  computation_group.add_argument(
295
520
  "--temp_directory",
296
521
  dest="temp_directory",
@@ -315,11 +540,27 @@ def parse_args():
315
540
  help="Resolution to highpass filter template and target to in the same unit "
316
541
  "as the sampling rate of template and target (typically Ångstrom).",
317
542
  )
543
+ filter_group.add_argument(
544
+ "--no_pass_smooth",
545
+ dest="no_pass_smooth",
546
+ action="store_false",
547
+ default=True,
548
+ help="Whether a hard edge filter should be used for --lowpass and --highpass.",
549
+ )
550
+ filter_group.add_argument(
551
+ "--pass_format",
552
+ dest="pass_format",
553
+ type=str,
554
+ required=False,
555
+ choices=["sampling_rate", "voxel", "frequency"],
556
+ help="How values passed to --lowpass and --highpass should be interpreted. "
557
+ "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
558
+ )
318
559
  filter_group.add_argument(
319
560
  "--whiten_spectrum",
320
561
  dest="whiten_spectrum",
321
562
  action="store_true",
322
- default=False,
563
+ default=None,
323
564
  help="Apply spectral whitening to template and target based on target spectrum.",
324
565
  )
325
566
  filter_group.add_argument(
@@ -327,7 +568,7 @@ def parse_args():
327
568
  dest="wedge_axes",
328
569
  type=str,
329
570
  required=False,
330
- default="0,2",
571
+ default=None,
331
572
  help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
332
573
  "in z-direction and tilted over the x axis.",
333
574
  )
@@ -337,10 +578,10 @@ def parse_args():
337
578
  type=str,
338
579
  required=False,
339
580
  default=None,
340
- help="Path to a file with angles and corresponding doses, or comma separated "
341
- "start and stop stage tilt angle, e.g. 50,45, which yields a continuous wedge "
342
- "mask. Alternatively, a tilt step size can be specified like 50,45:5.0 to "
343
- "sample 5.0 degree tilt angle steps.",
581
+ help="Path to a tab-separated file containing the column angles and optionally "
582
+ " weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
583
+ " yields a continuous wedge mask. Alternatively, a tilt step size can be "
584
+ "specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
344
585
  )
345
586
  filter_group.add_argument(
346
587
  "--tilt_weighting",
@@ -351,17 +592,93 @@ def parse_args():
351
592
  default=None,
352
593
  help="Weighting scheme used to reweight individual tilts. Available options: "
353
594
  "angle (cosine based weighting), "
354
- "relion (relion formalism for wedge weighting ),"
595
+ "relion (relion formalism for wedge weighting) requires,"
355
596
  "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
356
- "",
597
+ "relion and grigorieff require electron doses in --tilt_angles weights column.",
357
598
  )
358
599
  filter_group.add_argument(
600
+ "--reconstruction_filter",
601
+ dest="reconstruction_filter",
602
+ type=str,
603
+ required=False,
604
+ choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
605
+ default=None,
606
+ help="Filter applied when reconstructing (N+1)-D from N-D filters.",
607
+ )
608
+ filter_group.add_argument(
609
+ "--reconstruction_interpolation_order",
610
+ dest="reconstruction_interpolation_order",
611
+ type=int,
612
+ default=1,
613
+ required=False,
614
+ help="Analogous to --interpolation_order but for reconstruction.",
615
+ )
616
+
617
+ ctf_group = parser.add_argument_group("Contrast Transfer Function")
618
+ ctf_group.add_argument(
359
619
  "--ctf_file",
360
620
  dest="ctf_file",
361
621
  type=str,
362
622
  required=False,
363
623
  default=None,
364
- help="Path to a file with CTF parameters.",
624
+ help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
625
+ "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
626
+ )
627
+ ctf_group.add_argument(
628
+ "--defocus",
629
+ dest="defocus",
630
+ type=float,
631
+ required=False,
632
+ default=None,
633
+ help="Defocus in units of sampling rate (typically Ångstrom). "
634
+ "Superseded by --ctf_file.",
635
+ )
636
+ ctf_group.add_argument(
637
+ "--phase_shift",
638
+ dest="phase_shift",
639
+ type=float,
640
+ required=False,
641
+ default=0,
642
+ help="Phase shift in degrees. Superseded by --ctf_file.",
643
+ )
644
+ ctf_group.add_argument(
645
+ "--acceleration_voltage",
646
+ dest="acceleration_voltage",
647
+ type=float,
648
+ required=False,
649
+ default=300,
650
+ help="Acceleration voltage in kV, defaults to 300.",
651
+ )
652
+ ctf_group.add_argument(
653
+ "--spherical_aberration",
654
+ dest="spherical_aberration",
655
+ type=float,
656
+ required=False,
657
+ default=2.7e7,
658
+ help="Spherical aberration in units of sampling rate (typically Ångstrom).",
659
+ )
660
+ ctf_group.add_argument(
661
+ "--amplitude_contrast",
662
+ dest="amplitude_contrast",
663
+ type=float,
664
+ required=False,
665
+ default=0.07,
666
+ help="Amplitude contrast, defaults to 0.07.",
667
+ )
668
+ ctf_group.add_argument(
669
+ "--no_flip_phase",
670
+ dest="no_flip_phase",
671
+ action="store_false",
672
+ required=False,
673
+ help="Whether the phase of the computed CTF should not be flipped.",
674
+ )
675
+ ctf_group.add_argument(
676
+ "--correct_defocus_gradient",
677
+ dest="correct_defocus_gradient",
678
+ action="store_true",
679
+ required=False,
680
+ help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
681
+ "defocus gradients.",
365
682
  )
366
683
 
367
684
  performance_group = parser.add_argument_group("Performance")
@@ -413,6 +730,21 @@ def parse_args():
413
730
  help="Spline interpolation used for template rotations. If less than zero "
414
731
  "no interpolation is performed.",
415
732
  )
733
+ performance_group.add_argument(
734
+ "--use_mixed_precision",
735
+ dest="use_mixed_precision",
736
+ action="store_true",
737
+ default=False,
738
+ help="Use float16 for real values operations where possible.",
739
+ )
740
+ performance_group.add_argument(
741
+ "--use_memmap",
742
+ dest="use_memmap",
743
+ action="store_true",
744
+ default=False,
745
+ help="Use memmaps to offload large data objects to disk. "
746
+ "Particularly useful for large inputs in combination with --use_gpu.",
747
+ )
416
748
 
417
749
  analyzer_group = parser.add_argument_group("Analyzer")
418
750
  analyzer_group.add_argument(
@@ -423,14 +755,9 @@ def parse_args():
423
755
  default=0,
424
756
  help="Minimum template matching scores to consider for analysis.",
425
757
  )
426
- analyzer_group.add_argument(
427
- "-p",
428
- dest="peak_calling",
429
- action="store_true",
430
- default=False,
431
- help="Perform peak calling instead of score aggregation.",
432
- )
758
+
433
759
  args = parser.parse_args()
760
+ args.version = __version__
434
761
 
435
762
  if args.interpolation_order < 0:
436
763
  args.interpolation_order = None
@@ -467,94 +794,22 @@ def parse_args():
467
794
  int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
468
795
  ]
469
796
 
470
- if args.wedge_axes is not None:
471
- args.wedge_axes = [int(x) for x in args.wedge_axes.split(",")]
472
-
473
- if args.tilt_angles is not None and args.wedge_axes is None:
474
- raise ValueError("Wedge axes have to be specified with tilt angles.")
475
-
476
- if args.ctf_file is not None and args.wedge_axes is None:
477
- raise ValueError("Wedge axes have to be specified with CTF parameters.")
478
- if args.ctf_file is not None and args.tilt_angles is None:
479
- raise ValueError("Angles have to be specified with CTF parameters.")
480
-
481
- return args
482
-
483
-
484
- def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
485
- from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
486
-
487
- template_filter, target_filter = [], []
488
797
  if args.tilt_angles is not None:
489
- from tme.preprocessing.tilt_series import (
490
- Wedge,
491
- WedgeReconstructed,
492
- ReconstructFromTilt,
493
- )
494
-
495
- try:
496
- wedge = Wedge.from_file(args.tilt_angles)
497
- wedge.weight_type = args.tilt_weighting
498
- except FileNotFoundError:
499
- tilt_step = None
500
- tilt_start, tilt_stop = args.tilt_angles.split(",")
501
- if ":" in tilt_stop:
502
- tilt_stop, tilt_step = tilt_stop.split(":")
503
- tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
504
- tilt_angles = None
505
- if tilt_step is not None:
506
- tilt_step = float(tilt_step)
507
- tilt_angles = np.arange(
508
- -tilt_start, tilt_stop + tilt_step, tilt_step
509
- ).tolist()
510
- wedge = WedgeReconstructed(
511
- angles=tilt_angles,
512
- start_tilt=tilt_start,
513
- stop_tilt=tilt_stop,
514
- )
515
- wedge.opening_axis = args.wedge_axes[0]
516
- wedge.tilt_axis = args.wedge_axes[1]
517
- wedge.sampling_rate = template.sampling_rate
518
- template_filter.append(wedge)
519
- if not isinstance(wedge, WedgeReconstructed):
520
- template_filter.append(ReconstructFromTilt())
798
+ if args.wedge_axes is None:
799
+ raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
800
+ if not exists(args.tilt_angles):
801
+ try:
802
+ float(args.tilt_angles.split(",")[0])
803
+ except ValueError:
804
+ raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
521
805
 
522
- if args.ctf_file is not None:
523
- from tme.preprocessing.tilt_series import CTF
524
-
525
- ctf = CTF.from_file(args.ctf_file)
526
- ctf.tilt_axis = args.wedge_axes[1]
527
- ctf.opening_axis = args.wedge_axes[0]
528
- template_filter.append(ctf)
529
- if isinstance(template_filter[-1], ReconstructFromTilt):
530
- template_filter.insert(-1, ctf)
531
- else:
532
- template_filter.insert(0, ctf)
533
- template_filter.isnert(1, ReconstructFromTilt())
534
-
535
- if args.lowpass or args.highpass is not None:
536
- from tme.preprocessing import BandPassFilter
537
-
538
- bandpass = BandPassFilter(
539
- use_gaussian=True,
540
- lowpass=args.lowpass,
541
- highpass=args.highpass,
542
- sampling_rate=template.sampling_rate,
543
- )
544
- template_filter.append(bandpass)
545
- target_filter.append(bandpass)
546
-
547
- if args.whiten_spectrum:
548
- from tme.preprocessing import LinearWhiteningFilter
549
-
550
- whitening_filter = LinearWhiteningFilter()
551
- template_filter.append(whitening_filter)
552
- target_filter.append(whitening_filter)
806
+ if args.ctf_file is not None and args.tilt_angles is None:
807
+ raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
553
808
 
554
- template_filter = Compose(template_filter) if len(template_filter) else None
555
- target_filter = Compose(target_filter) if len(target_filter) else None
809
+ if args.wedge_axes is not None:
810
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
556
811
 
557
- return template_filter, target_filter
812
+ return args
558
813
 
559
814
 
560
815
  def main():
@@ -566,17 +821,20 @@ def main():
566
821
  try:
567
822
  template = Density.from_file(args.template)
568
823
  except Exception:
824
+ drop = target.metadata.get("batch_dimension", ())
825
+ keep = [i not in drop for i in range(target.data.ndim)]
569
826
  template = Density.from_structure(
570
827
  filename_or_structure=args.template,
571
- sampling_rate=target.sampling_rate,
828
+ sampling_rate=target.sampling_rate[keep],
572
829
  )
573
830
 
574
- if not np.allclose(target.sampling_rate, template.sampling_rate):
575
- print(
576
- f"Resampling template to {target.sampling_rate}. "
577
- "Consider providing a template with the same sampling rate as the target."
578
- )
579
- template = template.resample(target.sampling_rate, order=3)
831
+ if target.sampling_rate.size == template.sampling_rate.size:
832
+ if not np.allclose(target.sampling_rate, template.sampling_rate):
833
+ print(
834
+ f"Resampling template to {target.sampling_rate}. "
835
+ "Consider providing a template with the same sampling rate as the target."
836
+ )
837
+ template = template.resample(target.sampling_rate, order=3)
580
838
 
581
839
  template_mask = load_and_validate_mask(
582
840
  mask_target=template, mask_path=args.template_mask
@@ -709,31 +967,52 @@ def main():
709
967
  if args.memory is None:
710
968
  args.memory = int(args.memory_scaling * available_memory)
711
969
 
712
- target_padding = np.zeros_like(template.shape)
713
- if args.pad_target_edges:
714
- target_padding = template.shape
970
+ callback_class = MaxScoreOverRotations
971
+ if args.peak_calling:
972
+ callback_class = PeakCallerMaximumFilter
973
+
974
+ matching_data = MatchingData(
975
+ target=target,
976
+ template=template.data,
977
+ target_mask=target_mask,
978
+ template_mask=template_mask,
979
+ invert_target=args.invert_target_contrast,
980
+ rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
981
+ )
715
982
 
716
- template_box = template.shape
983
+ template_filter, target_filter = setup_filter(args, template, target)
984
+ matching_data.template_filter = template_filter
985
+ matching_data.target_filter = target_filter
986
+
987
+ target_dims = target.metadata.get("batch_dimension", None)
988
+ matching_data._set_batch_dimension(target_dims=target_dims, template_dims=None)
989
+ args.score = "FLC2" if target_dims is not None else args.score
990
+ args.target_batch, args.template_batch = target_dims, None
991
+
992
+ template_box = matching_data._output_template_shape
717
993
  if not args.pad_fourier:
718
994
  template_box = np.ones(len(template_box), dtype=int)
719
995
 
720
- callback_class = MaxScoreOverRotations
721
- if args.peak_calling:
722
- callback_class = PeakCallerMaximumFilter
996
+ target_padding = np.zeros(
997
+ (backend.size(matching_data._output_template_shape)), dtype=int
998
+ )
999
+ if args.pad_target_edges:
1000
+ target_padding = matching_data._output_template_shape
723
1001
 
724
1002
  splits, schedule = compute_parallelization_schedule(
725
1003
  shape1=target.shape,
726
- shape2=template_box,
727
- shape1_padding=target_padding,
1004
+ shape2=tuple(int(x) for x in template_box),
1005
+ shape1_padding=tuple(int(x) for x in target_padding),
728
1006
  max_cores=args.cores,
729
1007
  max_ram=args.memory,
730
1008
  split_only_outer=args.use_gpu,
731
1009
  matching_method=args.score,
732
1010
  analyzer_method=callback_class.__name__,
733
1011
  backend=backend._backend_name,
734
- float_nbytes=backend.datatype_bytes(backend._default_dtype),
1012
+ float_nbytes=backend.datatype_bytes(backend._float_dtype),
735
1013
  complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
736
- integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
1014
+ integer_nbytes=backend.datatype_bytes(backend._int_dtype),
1015
+ split_axes=target_dims,
737
1016
  )
738
1017
 
739
1018
  if splits is None:
@@ -743,32 +1022,7 @@ def main():
743
1022
  )
744
1023
  exit(-1)
745
1024
 
746
- analyzer_args = {
747
- "score_threshold": args.score_threshold,
748
- "number_of_peaks": 1000,
749
- "convolution_mode": "valid",
750
- "use_memmap": args.use_memmap,
751
- }
752
-
753
1025
  matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
754
- matching_data = MatchingData(target=target, template=template.data)
755
- matching_data.rotations = get_rotation_matrices(
756
- angular_sampling=args.angular_sampling, dim=target.data.ndim
757
- )
758
- if args.angular_sampling >= 180:
759
- ndim = target.data.ndim
760
- matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)
761
-
762
- template_filter, target_filter = setup_filter(args, template, target)
763
- matching_data.template_filter = template_filter
764
- matching_data.target_filter = target_filter
765
-
766
- matching_data._invert_target = args.invert_target_contrast
767
- if target_mask is not None:
768
- matching_data.target_mask = target_mask
769
- if template_mask is not None:
770
- matching_data.template_mask = template_mask.data
771
-
772
1026
  n_splits = np.prod(list(splits.values()))
773
1027
  target_split = ", ".join(
774
1028
  [":".join([str(x) for x in axis]) for axis in splits.items()]
@@ -798,10 +1052,45 @@ def main():
798
1052
  label_width=max(len(key) for key in options.keys()) + 2,
799
1053
  )
800
1054
 
801
- options = {"Analyzer": callback_class, **analyzer_args}
1055
+ filter_args = {
1056
+ "Lowpass": args.lowpass,
1057
+ "Highpass": args.highpass,
1058
+ "Smooth Pass": args.no_pass_smooth,
1059
+ "Pass Format": args.pass_format,
1060
+ "Spectral Whitening": args.whiten_spectrum,
1061
+ "Wedge Axes": args.wedge_axes,
1062
+ "Tilt Angles": args.tilt_angles,
1063
+ "Tilt Weighting": args.tilt_weighting,
1064
+ "Reconstruction Filter": args.reconstruction_filter,
1065
+ }
1066
+ if args.ctf_file is not None or args.defocus is not None:
1067
+ filter_args["CTF File"] = args.ctf_file
1068
+ filter_args["Defocus"] = args.defocus
1069
+ filter_args["Phase Shift"] = args.phase_shift
1070
+ filter_args["No Flip Phase"] = args.no_flip_phase
1071
+ filter_args["Acceleration Voltage"] = args.acceleration_voltage
1072
+ filter_args["Spherical Aberration"] = args.spherical_aberration
1073
+ filter_args["Amplitude Contrast"] = args.amplitude_contrast
1074
+ filter_args["Correct Defocus"] = args.correct_defocus_gradient
1075
+
1076
+ filter_args = {k: v for k, v in filter_args.items() if v is not None}
1077
+ if len(filter_args):
1078
+ print_block(
1079
+ name="Filters",
1080
+ data=filter_args,
1081
+ label_width=max(len(key) for key in options.keys()) + 2,
1082
+ )
1083
+
1084
+ analyzer_args = {
1085
+ "score_threshold": args.score_threshold,
1086
+ "number_of_peaks": 1000,
1087
+ "convolution_mode": "valid",
1088
+ "use_memmap": args.use_memmap,
1089
+ }
1090
+ analyzer_args = {"Analyzer": callback_class, **analyzer_args}
802
1091
  print_block(
803
1092
  name="Score Analysis Options",
804
- data=options,
1093
+ data=analyzer_args,
805
1094
  label_width=max(len(key) for key in options.keys()) + 2,
806
1095
  )
807
1096
  print("\n" + "-" * 80)
@@ -832,16 +1121,16 @@ def main():
832
1121
  candidates[0] *= target_mask.data
833
1122
  with warnings.catch_warnings():
834
1123
  warnings.simplefilter("ignore", category=UserWarning)
1124
+ nbytes = backend.datatype_bytes(backend._float_dtype)
1125
+ dtype = np.float32 if nbytes == 4 else np.float16
1126
+ rot_dim = matching_data.rotations.shape[1]
835
1127
  candidates[3] = {
836
1128
  x: euler_from_rotationmatrix(
837
- np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
838
- candidates[0].ndim, candidates[0].ndim
839
- )
1129
+ np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
840
1130
  )
841
1131
  for i, x in candidates[3].items()
842
1132
  }
843
-
844
- candidates.append((target.origin, template.origin, target.sampling_rate, args))
1133
+ candidates.append((target.origin, template.origin, template.sampling_rate, args))
845
1134
  write_pickle(data=candidates, filename=args.output)
846
1135
 
847
1136
  runtime = time() - start