pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -8,20 +8,19 @@
8
8
  import os
9
9
  import argparse
10
10
  import warnings
11
- import importlib.util
12
11
  from sys import exit
13
12
  from time import time
14
13
  from typing import Tuple
15
14
  from copy import deepcopy
16
- from os.path import abspath
15
+ from os.path import abspath, exists
17
16
 
18
17
  import numpy as np
19
18
 
20
19
  from tme import Density, __version__
21
20
  from tme.matching_utils import (
22
21
  get_rotation_matrices,
22
+ get_rotations_around_vector,
23
23
  compute_parallelization_schedule,
24
- euler_from_rotationmatrix,
25
24
  scramble_phases,
26
25
  generate_tempfile_name,
27
26
  write_pickle,
@@ -32,9 +31,9 @@ from tme.analyzer import (
32
31
  MaxScoreOverRotations,
33
32
  PeakCallerMaximumFilter,
34
33
  )
34
+ from tme.backends import backend as be
35
35
  from tme.preprocessing import Compose
36
- from tme.backends import backend
37
-
36
+ from tme.scoring import flc_scoring2
38
37
 
39
38
  def get_func_fullname(func) -> str:
40
39
  """Returns the full name of the given function, including its module."""
@@ -51,7 +50,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
51
50
 
52
51
  def print_entry() -> None:
53
52
  width = 80
54
- text = f" pyTME v{__version__} "
53
+ text = f" pytme v{__version__} "
55
54
  padding_total = width - len(text) - 2
56
55
  padding_left = padding_total // 2
57
56
  padding_right = padding_total - padding_left
@@ -152,8 +151,200 @@ def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
152
151
  return True
153
152
 
154
153
 
154
+ def parse_rotation_logic(args, ndim):
155
+ if args.angular_sampling is not None:
156
+ rotations = get_rotation_matrices(
157
+ angular_sampling=args.angular_sampling,
158
+ dim=ndim,
159
+ use_optimized_set=not args.no_use_optimized_set,
160
+ )
161
+ if args.angular_sampling >= 180:
162
+ rotations = np.eye(ndim).reshape(1, ndim, ndim)
163
+ return rotations
164
+
165
+ if args.axis_sampling is None:
166
+ args.axis_sampling = args.cone_sampling
167
+
168
+ rotations = get_rotations_around_vector(
169
+ cone_angle=args.cone_angle,
170
+ cone_sampling=args.cone_sampling,
171
+ axis_angle=args.axis_angle,
172
+ axis_sampling=args.axis_sampling,
173
+ n_symmetry=args.axis_symmetry,
174
+ )
175
+ return rotations
176
+
177
+
178
+ # TODO: Think about whether wedge mask should also be added to target
179
+ # For now leave it at the cost of incorrect upper bound on the scores
180
+ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
181
+ from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
182
+ from tme.preprocessing.tilt_series import (
183
+ Wedge,
184
+ WedgeReconstructed,
185
+ ReconstructFromTilt,
186
+ )
187
+
188
+ template_filter, target_filter = [], []
189
+ if args.tilt_angles is not None:
190
+ try:
191
+ wedge = Wedge.from_file(args.tilt_angles)
192
+ wedge.weight_type = args.tilt_weighting
193
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
194
+ wedge = WedgeReconstructed(
195
+ angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
196
+ )
197
+ except FileNotFoundError:
198
+ tilt_step, create_continuous_wedge = None, True
199
+ tilt_start, tilt_stop = args.tilt_angles.split(",")
200
+ if ":" in tilt_stop:
201
+ create_continuous_wedge = False
202
+ tilt_stop, tilt_step = tilt_stop.split(":")
203
+ tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
204
+ tilt_angles = (tilt_start, tilt_stop)
205
+ if tilt_step is not None:
206
+ tilt_step = float(tilt_step)
207
+ tilt_angles = np.arange(
208
+ -tilt_start, tilt_stop + tilt_step, tilt_step
209
+ ).tolist()
210
+
211
+ if args.tilt_weighting is not None and tilt_step is None:
212
+ raise ValueError(
213
+ "Tilt weighting is not supported for continuous wedges."
214
+ )
215
+ if args.tilt_weighting not in ("angle", None):
216
+ raise ValueError(
217
+ "Tilt weighting schemes other than 'angle' or 'None' require "
218
+ "a specification of electron doses via --tilt_angles."
219
+ )
220
+
221
+ wedge = Wedge(
222
+ angles=tilt_angles,
223
+ opening_axis=args.wedge_axes[0],
224
+ tilt_axis=args.wedge_axes[1],
225
+ shape=None,
226
+ weight_type=None,
227
+ weights=np.ones_like(tilt_angles),
228
+ )
229
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
230
+ wedge = WedgeReconstructed(
231
+ angles=tilt_angles,
232
+ weight_wedge=args.tilt_weighting == "angle",
233
+ create_continuous_wedge=create_continuous_wedge,
234
+ )
235
+
236
+ wedge.opening_axis = args.wedge_axes[0]
237
+ wedge.tilt_axis = args.wedge_axes[1]
238
+ wedge.sampling_rate = template.sampling_rate
239
+ template_filter.append(wedge)
240
+ if not isinstance(wedge, WedgeReconstructed):
241
+ template_filter.append(
242
+ ReconstructFromTilt(
243
+ reconstruction_filter=args.reconstruction_filter,
244
+ interpolation_order=args.reconstruction_interpolation_order,
245
+ )
246
+ )
247
+
248
+ if args.ctf_file is not None or args.defocus is not None:
249
+ from tme.preprocessing.tilt_series import CTF
250
+
251
+ needs_reconstruction = True
252
+ if args.ctf_file is not None:
253
+ ctf = CTF.from_file(args.ctf_file)
254
+ n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
255
+ if n_tilts_ctfs != n_tils_angles:
256
+ raise ValueError(
257
+ f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
258
+ f"recieved {n_tils_angles} tilt angles. Expected one angle "
259
+ "per micrograph."
260
+ )
261
+ ctf.angles = wedge.angles
262
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
263
+ else:
264
+ needs_reconstruction = False
265
+ ctf = CTF(
266
+ defocus_x=[args.defocus],
267
+ phase_shift=[args.phase_shift],
268
+ defocus_y=None,
269
+ angles=[0],
270
+ shape=None,
271
+ return_real_fourier=True,
272
+ )
273
+ ctf.sampling_rate = template.sampling_rate
274
+ ctf.flip_phase = args.no_flip_phase
275
+ ctf.amplitude_contrast = args.amplitude_contrast
276
+ ctf.spherical_aberration = args.spherical_aberration
277
+ ctf.acceleration_voltage = args.acceleration_voltage * 1e3
278
+ ctf.correct_defocus_gradient = args.correct_defocus_gradient
279
+
280
+ if not needs_reconstruction:
281
+ template_filter.append(ctf)
282
+ elif isinstance(template_filter[-1], ReconstructFromTilt):
283
+ template_filter.insert(-1, ctf)
284
+ else:
285
+ template_filter.insert(0, ctf)
286
+ template_filter.insert(
287
+ 1,
288
+ ReconstructFromTilt(
289
+ reconstruction_filter=args.reconstruction_filter,
290
+ interpolation_order=args.reconstruction_interpolation_order,
291
+ ),
292
+ )
293
+
294
+ if args.lowpass or args.highpass is not None:
295
+ lowpass, highpass = args.lowpass, args.highpass
296
+ if args.pass_format == "voxel":
297
+ if lowpass is not None:
298
+ lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
299
+ if highpass is not None:
300
+ highpass = np.max(np.multiply(highpass, template.sampling_rate))
301
+ elif args.pass_format == "frequency":
302
+ if lowpass is not None:
303
+ lowpass = np.max(np.divide(template.sampling_rate, lowpass))
304
+ if highpass is not None:
305
+ highpass = np.max(np.divide(template.sampling_rate, highpass))
306
+
307
+ try:
308
+ if args.lowpass >= args.highpass:
309
+ warnings.warn("--lowpass should be smaller than --highpass.")
310
+ except Exception:
311
+ pass
312
+
313
+ bandpass = BandPassFilter(
314
+ use_gaussian=args.no_pass_smooth,
315
+ lowpass=lowpass,
316
+ highpass=highpass,
317
+ sampling_rate=template.sampling_rate,
318
+ )
319
+ template_filter.append(bandpass)
320
+
321
+ if not args.no_filter_target:
322
+ target_filter.append(bandpass)
323
+
324
+ if args.whiten_spectrum:
325
+ whitening_filter = LinearWhiteningFilter()
326
+ template_filter.append(whitening_filter)
327
+ target_filter.append(whitening_filter)
328
+
329
+ needs_reconstruction = any(
330
+ [isinstance(t, ReconstructFromTilt) for t in template_filter]
331
+ )
332
+ if needs_reconstruction and args.reconstruction_filter is None:
333
+ warnings.warn(
334
+ "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
335
+ )
336
+
337
+ template_filter = Compose(template_filter) if len(template_filter) else None
338
+ target_filter = Compose(target_filter) if len(target_filter) else None
339
+
340
+ return template_filter, target_filter
341
+
342
+
155
343
  def parse_args():
156
- parser = argparse.ArgumentParser(description="Perform template matching.")
344
+ parser = argparse.ArgumentParser(
345
+ description="Perform template matching.",
346
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
347
+ )
157
348
 
158
349
  io_group = parser.add_argument_group("Input / Output")
159
350
  io_group.add_argument(
@@ -223,14 +414,65 @@ def parse_args():
223
414
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
224
415
  help="Template matching scoring function.",
225
416
  )
226
- scoring_group.add_argument(
417
+
418
+ angular_group = parser.add_argument_group("Angular Sampling")
419
+ angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
420
+
421
+ angular_exclusive.add_argument(
227
422
  "-a",
228
423
  dest="angular_sampling",
229
424
  type=check_positive,
230
- default=40.0,
231
- help="Angular sampling rate for template matching. "
425
+ default=None,
426
+ help="Angular sampling rate using optimized rotational sets."
232
427
  "A lower number yields more rotations. Values >= 180 sample only the identity.",
233
428
  )
429
+ angular_exclusive.add_argument(
430
+ "--cone_angle",
431
+ dest="cone_angle",
432
+ type=check_positive,
433
+ default=None,
434
+ help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
435
+ "narrow interval around a known orientation, e.g. for surface oversampling.",
436
+ )
437
+ angular_group.add_argument(
438
+ "--cone_sampling",
439
+ dest="cone_sampling",
440
+ type=check_positive,
441
+ default=None,
442
+ help="Sampling rate of the cone in degrees.",
443
+ )
444
+ angular_group.add_argument(
445
+ "--axis_angle",
446
+ dest="axis_angle",
447
+ type=check_positive,
448
+ default=360.0,
449
+ required=False,
450
+ help="Sampling angle along the z-axis of the cone.",
451
+ )
452
+ angular_group.add_argument(
453
+ "--axis_sampling",
454
+ dest="axis_sampling",
455
+ type=check_positive,
456
+ default=None,
457
+ required=False,
458
+ help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
459
+ )
460
+ angular_group.add_argument(
461
+ "--axis_symmetry",
462
+ dest="axis_symmetry",
463
+ type=check_positive,
464
+ default=1,
465
+ required=False,
466
+ help="N-fold symmetry around z-axis of the cone.",
467
+ )
468
+ angular_group.add_argument(
469
+ "--no_use_optimized_set",
470
+ dest="no_use_optimized_set",
471
+ action="store_true",
472
+ default=False,
473
+ required=False,
474
+ help="Whether to use random uniform instead of optimized rotation sets.",
475
+ )
234
476
 
235
477
  computation_group = parser.add_argument_group("Computation")
236
478
  computation_group.add_argument(
@@ -273,23 +515,7 @@ def parse_args():
273
515
  required=False,
274
516
  type=float,
275
517
  default=0.85,
276
- help="Fraction of available memory that can be used. Defaults to 0.85 and is "
277
- "ignored if --ram is set",
278
- )
279
- computation_group.add_argument(
280
- "--use_mixed_precision",
281
- dest="use_mixed_precision",
282
- action="store_true",
283
- default=False,
284
- help="Use float16 for real values operations where possible.",
285
- )
286
- computation_group.add_argument(
287
- "--use_memmap",
288
- dest="use_memmap",
289
- action="store_true",
290
- default=False,
291
- help="Use memmaps to offload large data objects to disk. "
292
- "Particularly useful for large inputs in combination with --use_gpu.",
518
+ help="Fraction of available memory to be used. Ignored if --ram is set."
293
519
  )
294
520
  computation_group.add_argument(
295
521
  "--temp_directory",
@@ -297,6 +523,13 @@ def parse_args():
297
523
  default=None,
298
524
  help="Directory for temporary objects. Faster I/O improves runtime.",
299
525
  )
526
+ computation_group.add_argument(
527
+ "--backend",
528
+ dest="backend",
529
+ default=None,
530
+ choices=be.available_backends(),
531
+ help="[Expert] Overwrite default computation backend.",
532
+ )
300
533
 
301
534
  filter_group = parser.add_argument_group("Filters")
302
535
  filter_group.add_argument(
@@ -315,11 +548,27 @@ def parse_args():
315
548
  help="Resolution to highpass filter template and target to in the same unit "
316
549
  "as the sampling rate of template and target (typically Ångstrom).",
317
550
  )
551
+ filter_group.add_argument(
552
+ "--no_pass_smooth",
553
+ dest="no_pass_smooth",
554
+ action="store_false",
555
+ default=True,
556
+ help="Whether a hard edge filter should be used for --lowpass and --highpass.",
557
+ )
558
+ filter_group.add_argument(
559
+ "--pass_format",
560
+ dest="pass_format",
561
+ type=str,
562
+ required=False,
563
+ default="sampling_rate",
564
+ choices=["sampling_rate", "voxel", "frequency"],
565
+ help="How values passed to --lowpass and --highpass should be interpreted. ",
566
+ )
318
567
  filter_group.add_argument(
319
568
  "--whiten_spectrum",
320
569
  dest="whiten_spectrum",
321
570
  action="store_true",
322
- default=False,
571
+ default=None,
323
572
  help="Apply spectral whitening to template and target based on target spectrum.",
324
573
  )
325
574
  filter_group.add_argument(
@@ -327,7 +576,7 @@ def parse_args():
327
576
  dest="wedge_axes",
328
577
  type=str,
329
578
  required=False,
330
- default="0,2",
579
+ default=None,
331
580
  help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
332
581
  "in z-direction and tilted over the x axis.",
333
582
  )
@@ -337,10 +586,10 @@ def parse_args():
337
586
  type=str,
338
587
  required=False,
339
588
  default=None,
340
- help="Path to a file with angles and corresponding doses, or comma separated "
341
- "start and stop stage tilt angle, e.g. 50,45, which yields a continuous wedge "
342
- "mask. Alternatively, a tilt step size can be specified like 50,45:5.0 to "
343
- "sample 5.0 degree tilt angle steps.",
589
+ help="Path to a tab-separated file containing the column angles and optionally "
590
+ " weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
591
+ " yields a continuous wedge mask. Alternatively, a tilt step size can be "
592
+ "specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
344
593
  )
345
594
  filter_group.add_argument(
346
595
  "--tilt_weighting",
@@ -351,17 +600,100 @@ def parse_args():
351
600
  default=None,
352
601
  help="Weighting scheme used to reweight individual tilts. Available options: "
353
602
  "angle (cosine based weighting), "
354
- "relion (relion formalism for wedge weighting ),"
603
+ "relion (relion formalism for wedge weighting) requires,"
355
604
  "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
356
- "",
605
+ "relion and grigorieff require electron doses in --tilt_angles weights column.",
606
+ )
607
+ filter_group.add_argument(
608
+ "--reconstruction_filter",
609
+ dest="reconstruction_filter",
610
+ type=str,
611
+ required=False,
612
+ choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
613
+ default=None,
614
+ help="Filter applied when reconstructing (N+1)-D from N-D filters.",
615
+ )
616
+ filter_group.add_argument(
617
+ "--reconstruction_interpolation_order",
618
+ dest="reconstruction_interpolation_order",
619
+ type=int,
620
+ default=1,
621
+ required=False,
622
+ help="Analogous to --interpolation_order but for reconstruction.",
357
623
  )
358
624
  filter_group.add_argument(
625
+ "--no_filter_target",
626
+ dest="no_filter_target",
627
+ action="store_true",
628
+ default=False,
629
+ help="Whether to not apply potential filters to the target.",
630
+ )
631
+
632
+ ctf_group = parser.add_argument_group("Contrast Transfer Function")
633
+ ctf_group.add_argument(
359
634
  "--ctf_file",
360
635
  dest="ctf_file",
361
636
  type=str,
362
637
  required=False,
363
638
  default=None,
364
- help="Path to a file with CTF parameters.",
639
+ help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
640
+ "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
641
+ )
642
+ ctf_group.add_argument(
643
+ "--defocus",
644
+ dest="defocus",
645
+ type=float,
646
+ required=False,
647
+ default=None,
648
+ help="Defocus in units of sampling rate (typically Ångstrom). "
649
+ "Superseded by --ctf_file.",
650
+ )
651
+ ctf_group.add_argument(
652
+ "--phase_shift",
653
+ dest="phase_shift",
654
+ type=float,
655
+ required=False,
656
+ default=0,
657
+ help="Phase shift in degrees. Superseded by --ctf_file.",
658
+ )
659
+ ctf_group.add_argument(
660
+ "--acceleration_voltage",
661
+ dest="acceleration_voltage",
662
+ type=float,
663
+ required=False,
664
+ default=300,
665
+ help="Acceleration voltage in kV.",
666
+ )
667
+ ctf_group.add_argument(
668
+ "--spherical_aberration",
669
+ dest="spherical_aberration",
670
+ type=float,
671
+ required=False,
672
+ default=2.7e7,
673
+ help="Spherical aberration in units of sampling rate (typically Ångstrom).",
674
+ )
675
+ ctf_group.add_argument(
676
+ "--amplitude_contrast",
677
+ dest="amplitude_contrast",
678
+ type=float,
679
+ required=False,
680
+ default=0.07,
681
+ help="Amplitude contrast.",
682
+ )
683
+ ctf_group.add_argument(
684
+ "--no_flip_phase",
685
+ dest="no_flip_phase",
686
+ action="store_false",
687
+ required=False,
688
+ help="Perform phase-flipping CTF correction.",
689
+ )
690
+ ctf_group.add_argument(
691
+ "--correct_defocus_gradient",
692
+ dest="correct_defocus_gradient",
693
+ action="store_true",
694
+ required=False,
695
+ help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
696
+ "defocus gradients.",
365
697
  )
366
698
 
367
699
  performance_group = parser.add_argument_group("Performance")
@@ -404,14 +736,37 @@ def parse_args():
404
736
  "for numerical stability. When working with very large targets, e.g. tomograms, "
405
737
  "it is safe to use this flag and benefit from the performance gain.",
406
738
  )
739
+ performance_group.add_argument(
740
+ "--no_filter_padding",
741
+ dest="no_filter_padding",
742
+ action="store_true",
743
+ default=False,
744
+ help="Omits padding of optional template filters. Particularly effective when "
745
+ "the target is much larger than the template. However, for fast osciliating "
746
+ "filters setting this flag can introduce aliasing effects.",
747
+ )
407
748
  performance_group.add_argument(
408
749
  "--interpolation_order",
409
750
  dest="interpolation_order",
410
751
  required=False,
411
752
  type=int,
412
753
  default=3,
413
- help="Spline interpolation used for template rotations. If less than zero "
414
- "no interpolation is performed.",
754
+ help="Spline interpolation used for rotations.",
755
+ )
756
+ performance_group.add_argument(
757
+ "--use_mixed_precision",
758
+ dest="use_mixed_precision",
759
+ action="store_true",
760
+ default=False,
761
+ help="Use float16 for real values operations where possible.",
762
+ )
763
+ performance_group.add_argument(
764
+ "--use_memmap",
765
+ dest="use_memmap",
766
+ action="store_true",
767
+ default=False,
768
+ help="Use memmaps to offload large data objects to disk. "
769
+ "Particularly useful for large inputs in combination with --use_gpu.",
415
770
  )
416
771
 
417
772
  analyzer_group = parser.add_argument_group("Analyzer")
@@ -430,7 +785,15 @@ def parse_args():
430
785
  default=False,
431
786
  help="Perform peak calling instead of score aggregation.",
432
787
  )
788
+ analyzer_group.add_argument(
789
+ "--number_of_peaks",
790
+ dest="number_of_peaks",
791
+ action="store_true",
792
+ default=1000,
793
+ help="Number of peaks to call, 1000 by default..",
794
+ )
433
795
  args = parser.parse_args()
796
+ args.version = __version__
434
797
 
435
798
  if args.interpolation_order < 0:
436
799
  args.interpolation_order = None
@@ -467,94 +830,22 @@ def parse_args():
467
830
  int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
468
831
  ]
469
832
 
470
- if args.wedge_axes is not None:
471
- args.wedge_axes = [int(x) for x in args.wedge_axes.split(",")]
472
-
473
- if args.tilt_angles is not None and args.wedge_axes is None:
474
- raise ValueError("Wedge axes have to be specified with tilt angles.")
475
-
476
- if args.ctf_file is not None and args.wedge_axes is None:
477
- raise ValueError("Wedge axes have to be specified with CTF parameters.")
478
- if args.ctf_file is not None and args.tilt_angles is None:
479
- raise ValueError("Angles have to be specified with CTF parameters.")
480
-
481
- return args
482
-
483
-
484
- def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
485
- from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
486
-
487
- template_filter, target_filter = [], []
488
833
  if args.tilt_angles is not None:
489
- from tme.preprocessing.tilt_series import (
490
- Wedge,
491
- WedgeReconstructed,
492
- ReconstructFromTilt,
493
- )
494
-
495
- try:
496
- wedge = Wedge.from_file(args.tilt_angles)
497
- wedge.weight_type = args.tilt_weighting
498
- except FileNotFoundError:
499
- tilt_step = None
500
- tilt_start, tilt_stop = args.tilt_angles.split(",")
501
- if ":" in tilt_stop:
502
- tilt_stop, tilt_step = tilt_stop.split(":")
503
- tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
504
- tilt_angles = None
505
- if tilt_step is not None:
506
- tilt_step = float(tilt_step)
507
- tilt_angles = np.arange(
508
- -tilt_start, tilt_stop + tilt_step, tilt_step
509
- ).tolist()
510
- wedge = WedgeReconstructed(
511
- angles=tilt_angles,
512
- start_tilt=tilt_start,
513
- stop_tilt=tilt_stop,
514
- )
515
- wedge.opening_axis = args.wedge_axes[0]
516
- wedge.tilt_axis = args.wedge_axes[1]
517
- wedge.sampling_rate = template.sampling_rate
518
- template_filter.append(wedge)
519
- if not isinstance(wedge, WedgeReconstructed):
520
- template_filter.append(ReconstructFromTilt())
521
-
522
- if args.ctf_file is not None:
523
- from tme.preprocessing.tilt_series import CTF
524
-
525
- ctf = CTF.from_file(args.ctf_file)
526
- ctf.tilt_axis = args.wedge_axes[1]
527
- ctf.opening_axis = args.wedge_axes[0]
528
- template_filter.append(ctf)
529
- if isinstance(template_filter[-1], ReconstructFromTilt):
530
- template_filter.insert(-1, ctf)
531
- else:
532
- template_filter.insert(0, ctf)
533
- template_filter.isnert(1, ReconstructFromTilt())
534
-
535
- if args.lowpass or args.highpass is not None:
536
- from tme.preprocessing import BandPassFilter
537
-
538
- bandpass = BandPassFilter(
539
- use_gaussian=True,
540
- lowpass=args.lowpass,
541
- highpass=args.highpass,
542
- sampling_rate=template.sampling_rate,
543
- )
544
- template_filter.append(bandpass)
545
- target_filter.append(bandpass)
546
-
547
- if args.whiten_spectrum:
548
- from tme.preprocessing import LinearWhiteningFilter
834
+ if args.wedge_axes is None:
835
+ raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
836
+ if not exists(args.tilt_angles):
837
+ try:
838
+ float(args.tilt_angles.split(",")[0])
839
+ except ValueError:
840
+ raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
549
841
 
550
- whitening_filter = LinearWhiteningFilter()
551
- template_filter.append(whitening_filter)
552
- target_filter.append(whitening_filter)
842
+ if args.ctf_file is not None and args.tilt_angles is None:
843
+ raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
553
844
 
554
- template_filter = Compose(template_filter) if len(template_filter) else None
555
- target_filter = Compose(target_filter) if len(target_filter) else None
845
+ if args.wedge_axes is not None:
846
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
556
847
 
557
- return template_filter, target_filter
848
+ return args
558
849
 
559
850
 
560
851
  def main():
@@ -566,17 +857,20 @@ def main():
566
857
  try:
567
858
  template = Density.from_file(args.template)
568
859
  except Exception:
860
+ drop = target.metadata.get("batch_dimension", ())
861
+ keep = [i not in drop for i in range(target.data.ndim)]
569
862
  template = Density.from_structure(
570
863
  filename_or_structure=args.template,
571
- sampling_rate=target.sampling_rate,
864
+ sampling_rate=target.sampling_rate[keep],
572
865
  )
573
866
 
574
- if not np.allclose(target.sampling_rate, template.sampling_rate):
575
- print(
576
- f"Resampling template to {target.sampling_rate}. "
577
- "Consider providing a template with the same sampling rate as the target."
578
- )
579
- template = template.resample(target.sampling_rate, order=3)
867
+ if target.sampling_rate.size == template.sampling_rate.size:
868
+ if not np.allclose(target.sampling_rate, template.sampling_rate):
869
+ print(
870
+ f"Resampling template to {target.sampling_rate}. "
871
+ "Consider providing a template with the same sampling rate as the target."
872
+ )
873
+ template = template.resample(target.sampling_rate, order=3)
580
874
 
581
875
  template_mask = load_and_validate_mask(
582
876
  mask_target=template, mask_path=args.template_mask
@@ -668,72 +962,104 @@ def main():
668
962
  template.data, noise_proportion=1.0, normalize_power=True
669
963
  )
670
964
 
671
- available_memory = backend.get_available_memory()
672
- if args.use_gpu:
673
- args.cores = len(args.gpu_indices)
674
- has_torch = importlib.util.find_spec("torch") is not None
675
- has_cupy = importlib.util.find_spec("cupy") is not None
676
-
677
- if not has_torch and not has_cupy:
965
+ # Determine suitable backend for the selected operation
966
+ available_backends = be.available_backends()
967
+ if args.backend is not None:
968
+ req_backend = args.backend
969
+ if req_backend not in available_backends:
678
970
  raise ValueError(
679
- "Found neither CuPy nor PyTorch installation. You need to install"
680
- " either to enable GPU support."
971
+ "Requested backend is not available."
681
972
  )
973
+ available_backends = [req_backend,]
682
974
 
683
- if args.peak_calling:
684
- preferred_backend = "pytorch"
685
- if not has_torch:
686
- preferred_backend = "cupy"
687
- backend.change_backend(backend_name=preferred_backend, device="cuda")
688
- else:
689
- preferred_backend = "cupy"
690
- if not has_cupy:
691
- preferred_backend = "pytorch"
692
- backend.change_backend(backend_name=preferred_backend, device="cuda")
693
- if args.use_mixed_precision and preferred_backend == "pytorch":
975
+ be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
976
+ if args.use_gpu:
977
+ args.cores = len(args.gpu_indices)
978
+ be_selection = ("pytorch", "cupy", "jax")
979
+ if args.use_mixed_precision:
980
+ be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
981
+
982
+ available_backends = [x for x in available_backends if x in be_selection]
983
+ if args.peak_calling:
984
+ if "jax" in available_backends:
985
+ available_backends.remove("jax")
986
+ if args.use_gpu and "pytorch" in available_backends:
987
+ available_backends = ("pytorch",)
988
+ if args.interpolation_order == 3:
694
989
  raise NotImplementedError(
695
- "pytorch backend does not yet support mixed precision."
696
- " Consider installing CuPy to enable this feature."
697
- )
698
- elif args.use_mixed_precision:
699
- backend.change_backend(
700
- backend_name="cupy",
701
- default_dtype=backend._array_backend.float16,
702
- complex_dtype=backend._array_backend.complex64,
703
- default_dtype_int=backend._array_backend.int16,
990
+ "Pytorch does not support --interpolation_order 3, 1 is supported."
704
991
  )
705
- available_memory = backend.get_available_memory() * args.cores
706
- if preferred_backend == "pytorch" and args.interpolation_order == 3:
707
- args.interpolation_order = 1
992
+ ndim = len(template.shape)
993
+ if len(target.shape) == ndim and ndim <= 3 and args.use_gpu:
994
+ available_backends = ["jax", ]
995
+
996
+ backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
997
+ if args.use_gpu:
998
+ backend_preference = ("cupy", "jax", "pytorch")
999
+ for pref in backend_preference:
1000
+ if pref not in available_backends:
1001
+ continue
1002
+ be.change_backend(pref)
1003
+ if pref == "pytorch":
1004
+ be.change_backend(pref, device = "cuda" if args.use_gpu else "cpu")
1005
+
1006
+ if args.use_mixed_precision:
1007
+ be.change_backend(
1008
+ backend_name=pref,
1009
+ default_dtype=be._array_backend.float16,
1010
+ complex_dtype=be._array_backend.complex64,
1011
+ default_dtype_int=be._array_backend.int16,
1012
+ )
1013
+ break
708
1014
 
1015
+ available_memory = be.get_available_memory() * be.device_count()
709
1016
  if args.memory is None:
710
1017
  args.memory = int(args.memory_scaling * available_memory)
711
1018
 
712
- target_padding = np.zeros_like(template.shape)
713
- if args.pad_target_edges:
714
- target_padding = template.shape
715
-
716
- template_box = template.shape
717
- if not args.pad_fourier:
718
- template_box = np.ones(len(template_box), dtype=int)
719
-
720
1019
  callback_class = MaxScoreOverRotations
721
1020
  if args.peak_calling:
722
1021
  callback_class = PeakCallerMaximumFilter
723
1022
 
1023
+ matching_data = MatchingData(
1024
+ target=target,
1025
+ template=template.data,
1026
+ target_mask=target_mask,
1027
+ template_mask=template_mask,
1028
+ invert_target=args.invert_target_contrast,
1029
+ rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
1030
+ )
1031
+
1032
+ matching_data.template_filter, matching_data.target_filter = setup_filter(
1033
+ args, template, target
1034
+ )
1035
+
1036
+ target_dims = target.metadata.get("batch_dimension", None)
1037
+ matching_data._set_matching_dimension(target_dims=target_dims, template_dims=None)
1038
+ args.score = "FLC" if target_dims is not None else args.score
1039
+ args.target_batch, args.template_batch = target_dims, None
1040
+
1041
+ template_box = matching_data._output_template_shape
1042
+ if not args.pad_fourier:
1043
+ template_box = tuple(0 for _ in range(len(template_box)))
1044
+
1045
+ target_padding = tuple(0 for _ in range(len(template_box)))
1046
+ if args.pad_target_edges:
1047
+ target_padding = matching_data._output_template_shape
1048
+
724
1049
  splits, schedule = compute_parallelization_schedule(
725
1050
  shape1=target.shape,
726
- shape2=template_box,
727
- shape1_padding=target_padding,
1051
+ shape2=tuple(int(x) for x in template_box),
1052
+ shape1_padding=tuple(int(x) for x in target_padding),
728
1053
  max_cores=args.cores,
729
1054
  max_ram=args.memory,
730
1055
  split_only_outer=args.use_gpu,
731
1056
  matching_method=args.score,
732
1057
  analyzer_method=callback_class.__name__,
733
- backend=backend._backend_name,
734
- float_nbytes=backend.datatype_bytes(backend._default_dtype),
735
- complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
736
- integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
1058
+ backend=be._backend_name,
1059
+ float_nbytes=be.datatype_bytes(be._float_dtype),
1060
+ complex_nbytes=be.datatype_bytes(be._complex_dtype),
1061
+ integer_nbytes=be.datatype_bytes(be._int_dtype),
1062
+ split_axes=target_dims,
737
1063
  )
738
1064
 
739
1065
  if splits is None:
@@ -743,66 +1069,87 @@ def main():
743
1069
  )
744
1070
  exit(-1)
745
1071
 
746
- analyzer_args = {
747
- "score_threshold": args.score_threshold,
748
- "number_of_peaks": 1000,
749
- "convolution_mode": "valid",
750
- "use_memmap": args.use_memmap,
751
- }
752
-
753
1072
  matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
754
- matching_data = MatchingData(target=target, template=template.data)
755
- matching_data.rotations = get_rotation_matrices(
756
- angular_sampling=args.angular_sampling, dim=target.data.ndim
757
- )
758
- if args.angular_sampling >= 180:
759
- ndim = target.data.ndim
760
- matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)
761
-
762
- template_filter, target_filter = setup_filter(args, template, target)
763
- matching_data.template_filter = template_filter
764
- matching_data.target_filter = target_filter
765
-
766
- matching_data._invert_target = args.invert_target_contrast
767
- if target_mask is not None:
768
- matching_data.target_mask = target_mask
769
- if template_mask is not None:
770
- matching_data.template_mask = template_mask.data
771
-
1073
+ if target_dims is not None:
1074
+ matching_score = flc_scoring2
772
1075
  n_splits = np.prod(list(splits.values()))
773
1076
  target_split = ", ".join(
774
1077
  [":".join([str(x) for x in axis]) for axis in splits.items()]
775
1078
  )
776
1079
  gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
777
1080
  options = {
778
- "CPU Cores": args.cores,
779
- "Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
780
- "Use Mixed Precision": args.use_mixed_precision,
781
- "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
782
- "Temporary Directory": args.temp_directory,
1081
+ "Angular Sampling": f"{args.angular_sampling}"
1082
+ f" [{matching_data.rotations.shape[0]} rotations]",
1083
+ "Center Template": not args.no_centering,
1084
+ "Scramble Template": args.scramble_phases,
1085
+ "Invert Contrast": args.invert_target_contrast,
783
1086
  "Extend Fourier Grid": not args.no_fourier_padding,
784
1087
  "Extend Target Edges": not args.no_edge_padding,
785
1088
  "Interpolation Order": args.interpolation_order,
786
- "Score": f"{args.score}",
787
1089
  "Setup Function": f"{get_func_fullname(matching_setup)}",
788
1090
  "Scoring Function": f"{get_func_fullname(matching_score)}",
789
- "Angular Sampling": f"{args.angular_sampling}"
790
- f" [{matching_data.rotations.shape[0]} rotations]",
791
- "Scramble Template": args.scramble_phases,
792
- "Target Splits": f"{target_split} [N={n_splits}]",
793
1091
  }
794
1092
 
795
1093
  print_block(
796
- name="Template Matching Options",
1094
+ name="Template Matching",
797
1095
  data=options,
798
- label_width=max(len(key) for key in options.keys()) + 2,
1096
+ label_width=max(len(key) for key in options.keys()) + 3,
1097
+ )
1098
+
1099
+ compute_options = {
1100
+ "Backend" :be._BACKEND_REGISTRY[be._backend_name],
1101
+ "Compute Devices" : f"CPU [{args.cores}], GPU [{gpus_used}]",
1102
+ "Use Mixed Precision": args.use_mixed_precision,
1103
+ "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1104
+ "Temporary Directory": args.temp_directory,
1105
+ "Target Splits": f"{target_split} [N={n_splits}]",
1106
+ }
1107
+ print_block(
1108
+ name="Computation",
1109
+ data=compute_options,
1110
+ label_width=max(len(key) for key in options.keys()) + 3,
799
1111
  )
800
1112
 
801
- options = {"Analyzer": callback_class, **analyzer_args}
1113
+ filter_args = {
1114
+ "Lowpass": args.lowpass,
1115
+ "Highpass": args.highpass,
1116
+ "Smooth Pass": args.no_pass_smooth,
1117
+ "Pass Format": args.pass_format,
1118
+ "Spectral Whitening": args.whiten_spectrum,
1119
+ "Wedge Axes": args.wedge_axes,
1120
+ "Tilt Angles": args.tilt_angles,
1121
+ "Tilt Weighting": args.tilt_weighting,
1122
+ "Reconstruction Filter": args.reconstruction_filter,
1123
+ }
1124
+ if args.ctf_file is not None or args.defocus is not None:
1125
+ filter_args["CTF File"] = args.ctf_file
1126
+ filter_args["Defocus"] = args.defocus
1127
+ filter_args["Phase Shift"] = args.phase_shift
1128
+ filter_args["Flip Phase"] = args.no_flip_phase
1129
+ filter_args["Acceleration Voltage"] = args.acceleration_voltage
1130
+ filter_args["Spherical Aberration"] = args.spherical_aberration
1131
+ filter_args["Amplitude Contrast"] = args.amplitude_contrast
1132
+ filter_args["Correct Defocus"] = args.correct_defocus_gradient
1133
+
1134
+ filter_args = {k: v for k, v in filter_args.items() if v is not None}
1135
+ if len(filter_args):
1136
+ print_block(
1137
+ name="Filters",
1138
+ data=filter_args,
1139
+ label_width=max(len(key) for key in options.keys()) + 3,
1140
+ )
1141
+
1142
+ analyzer_args = {
1143
+ "score_threshold": args.score_threshold,
1144
+ "number_of_peaks": args.number_of_peaks,
1145
+ "min_distance" : max(template.shape) // 2,
1146
+ "min_boundary_distance" : max(template.shape) // 2,
1147
+ "use_memmap": args.use_memmap,
1148
+ }
802
1149
  print_block(
803
- name="Score Analysis Options",
804
- data=options,
805
- label_width=max(len(key) for key in options.keys()) + 2,
1150
+ name="Analyzer",
1151
+ data={"Analyzer": callback_class, **analyzer_args},
1152
+ label_width=max(len(key) for key in options.keys()) + 3,
806
1153
  )
807
1154
  print("\n" + "-" * 80)
808
1155
 
@@ -823,6 +1170,7 @@ def main():
823
1170
  target_splits=splits,
824
1171
  pad_target_edges=args.pad_target_edges,
825
1172
  pad_fourier=args.pad_fourier,
1173
+ pad_template_filter=not args.no_filter_padding,
826
1174
  interpolation_order=args.interpolation_order,
827
1175
  )
828
1176
 
@@ -832,19 +1180,19 @@ def main():
832
1180
  candidates[0] *= target_mask.data
833
1181
  with warnings.catch_warnings():
834
1182
  warnings.simplefilter("ignore", category=UserWarning)
1183
+ nbytes = be.datatype_bytes(be._float_dtype)
1184
+ dtype = np.float32 if nbytes == 4 else np.float16
1185
+ rot_dim = matching_data.rotations.shape[1]
835
1186
  candidates[3] = {
836
- x: euler_from_rotationmatrix(
837
- np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
838
- candidates[0].ndim, candidates[0].ndim
839
- )
840
- )
1187
+ x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
841
1188
  for i, x in candidates[3].items()
842
1189
  }
843
-
844
- candidates.append((target.origin, template.origin, target.sampling_rate, args))
1190
+ print(np.where(candidates[0] == candidates[0].max()), candidates[0].max())
1191
+ candidates.append((target.origin, template.origin, template.sampling_rate, args))
845
1192
  write_pickle(data=candidates, filename=args.output)
846
1193
 
847
1194
  runtime = time() - start
1195
+ print("\n" + "-" * 80)
848
1196
  print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
849
1197
 
850
1198