pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
  6. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +49 -103
  15. scripts/pytme_runner.py +46 -69
  16. tests/preprocessing/test_compose.py +31 -30
  17. tests/preprocessing/test_frequency_filters.py +17 -32
  18. tests/preprocessing/test_preprocessor.py +0 -19
  19. tests/preprocessing/test_utils.py +13 -1
  20. tests/test_analyzer.py +2 -10
  21. tests/test_backends.py +47 -18
  22. tests/test_density.py +72 -13
  23. tests/test_extensions.py +1 -0
  24. tests/test_matching_cli.py +23 -9
  25. tests/test_matching_exhaustive.py +5 -5
  26. tests/test_matching_utils.py +3 -3
  27. tests/test_orientations.py +12 -0
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +91 -68
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +103 -98
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +44 -57
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +17 -3
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post2.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
scripts/match_template.py CHANGED
@@ -19,8 +19,9 @@ import numpy as np
19
19
 
20
20
  from tme.backends import backend as be
21
21
  from tme import Density, __version__, Orientations
22
- from tme.matching_utils import scramble_phases, write_pickle
23
- from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
22
+ from tme.matching_utils import write_pickle
23
+ from tme.matching_exhaustive import match_exhaustive
24
+ from tme.matching_scores import MATCHING_EXHAUSTIVE_REGISTER
24
25
  from tme.rotations import (
25
26
  get_cone_rotations,
26
27
  get_rotation_matrices,
@@ -37,6 +38,7 @@ from tme.filters import (
37
38
  Wedge,
38
39
  Compose,
39
40
  BandPass,
41
+ ShiftFourier,
40
42
  CTFReconstructed,
41
43
  WedgeReconstructed,
42
44
  ReconstructFromTilt,
@@ -129,21 +131,14 @@ def parse_rotation_logic(args, ndim):
129
131
 
130
132
 
131
133
  def compute_schedule(
132
- args,
133
- matching_data: MatchingData,
134
- callback_class,
135
- pad_edges: bool = False,
134
+ args, matching_data: MatchingData, callback_class, use_gpu: bool = False
136
135
  ):
137
- # User requested target padding
138
- if args.pad_edges is True:
139
- pad_edges = True
140
-
141
136
  splits, schedule = matching_data.computation_schedule(
142
137
  matching_method=args.score,
143
138
  analyzer_method=callback_class.__name__,
144
- use_gpu=args.use_gpu,
139
+ use_gpu=use_gpu,
145
140
  pad_fourier=False,
146
- pad_target_edges=pad_edges,
141
+ pad_target_edges=args.pad_edges,
147
142
  available_memory=args.memory,
148
143
  max_cores=args.cores,
149
144
  )
@@ -155,53 +150,63 @@ def compute_schedule(
155
150
  )
156
151
  exit(-1)
157
152
 
153
+ # Padding is required to avoid artifacts so setting it
158
154
  n_splits = np.prod(list(splits.values()))
159
- if pad_edges is False and len(matching_data._target_dim) == 0 and n_splits > 1:
155
+ if args.pad_edges is False and len(matching_data._target_dim) == 0 and n_splits > 1:
156
+ warnings.warn("Setting --pad-edges to avoid artifacts from splitting.")
160
157
  args.pad_edges = True
161
- return compute_schedule(args, matching_data, callback_class, True)
158
+ return compute_schedule(args, matching_data, callback_class, use_gpu)
162
159
  return splits, schedule
163
160
 
164
161
 
165
162
  def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
166
163
  template_filter, target_filter = [], []
167
164
 
168
- if args.tilt_angles is None:
169
- args.tilt_angles = args.ctf_file
170
-
171
165
  wedge = None
172
166
  if args.tilt_angles is not None:
173
167
  try:
174
168
  wedge = Wedge.from_file(args.tilt_angles)
175
169
  wedge.weight_type = args.tilt_weighting
176
- if args.tilt_weighting in ("angle", None):
170
+
171
+ # Avoid reconstructing the 3D wedge from individual tilts
172
+ if args.tilt_weighting in ("angle", None) and not args.match_projection:
177
173
  wedge = WedgeReconstructed(
178
174
  angles=wedge.angles,
179
175
  weight_wedge=args.tilt_weighting == "angle",
180
176
  )
177
+
181
178
  except (FileNotFoundError, AttributeError):
182
- tilt_start, tilt_stop = args.tilt_angles.split(",")
183
- tilt_start, tilt_stop = abs(float(tilt_start)), abs(float(tilt_stop))
184
179
  wedge = WedgeReconstructed(
185
- angles=(tilt_start, tilt_stop),
180
+ angles=args.tilt_angles,
186
181
  create_continuous_wedge=True,
187
182
  weight_wedge=False,
188
183
  reconstruction_filter=args.reconstruction_filter,
189
184
  )
190
- wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
191
-
192
- wedge_target = WedgeReconstructed(
193
- angles=wedge.angles,
194
- weight_wedge=False,
195
- create_continuous_wedge=True,
196
- opening_axis=wedge.opening_axis,
197
- tilt_axis=wedge.tilt_axis,
198
- )
199
185
 
200
186
  wedge.sampling_rate = template.sampling_rate
201
- wedge_target.sampling_rate = template.sampling_rate
187
+ wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
188
+ template_filter.append(wedge)
189
+
190
+ # When projection matching we can reuse the template wedge mask
191
+ wedge_target = wedge
192
+ if not args.match_projection:
193
+ wedge_target = WedgeReconstructed(
194
+ angles=wedge.angles,
195
+ weight_wedge=False,
196
+ create_continuous_wedge=True,
197
+ opening_axis=wedge.opening_axis,
198
+ tilt_axis=wedge.tilt_axis,
199
+ )
202
200
 
201
+ wedge_target.sampling_rate = template.sampling_rate
202
+ else:
203
+ n_angles, n_tilts = len(wedge_target.angles), target.shape[0]
204
+ if n_angles != n_tilts:
205
+ raise ValueError(
206
+ f"Target contains {n_tilts} tilts, but the input specified "
207
+ f"{n_angles} tilt angles."
208
+ )
203
209
  target_filter.append(wedge_target)
204
- template_filter.append(wedge)
205
210
 
206
211
  if args.ctf_file is not None or args.defocus is not None:
207
212
  try:
@@ -219,15 +224,19 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
219
224
  if len(ctf.angles) == 0:
220
225
  ctf.angles = wedge.angles
221
226
 
227
+ # There are several ways we can end up here. Bottom line, we are using
228
+ # a non-reconstructed wedge, which contains a different number of tilts
229
+ # than the ctf. We use defocus_x, as not all ctf_files specify angles.
222
230
  n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
223
- if (n_tilts_ctfs != n_tils_angles) and isinstance(wedge, Wedge):
231
+ if (n_tilts_ctfs != n_tils_angles) and type(wedge) is Wedge:
224
232
  raise ValueError(
225
233
  f"CTF file contains {n_tilts_ctfs} tilts, but received "
226
234
  f"{n_tils_angles} tilt angles. Expected one angle per tilt"
227
235
  )
228
236
 
229
237
  except (FileNotFoundError, AttributeError):
230
- ctf = CTFReconstructed(
238
+ ctf_cl = CTFReconstructed if not args.match_projection else CTF
239
+ ctf = ctf_cl(
231
240
  defocus_x=args.defocus,
232
241
  phase_shift=args.phase_shift,
233
242
  amplitude_contrast=args.amplitude_contrast,
@@ -237,10 +246,9 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
237
246
  ctf.flip_phase = args.no_flip_phase
238
247
  ctf.sampling_rate = template.sampling_rate
239
248
  ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
240
- ctf.correct_defocus_gradient = args.correct_defocus_gradient
241
249
  template_filter.append(ctf)
242
250
 
243
- if args.lowpass or args.highpass is not None:
251
+ if args.lowpass is not None or args.highpass is not None:
244
252
  lowpass, highpass = args.lowpass, args.highpass
245
253
  if args.pass_format == "voxel":
246
254
  if lowpass is not None:
@@ -255,49 +263,58 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
255
263
 
256
264
  try:
257
265
  if args.lowpass >= args.highpass:
258
- warnings.warn("--lowpass should be smaller than --highpass.")
266
+ raise ValueError("--lowpass should be smaller than --highpass.")
259
267
  except Exception:
260
268
  pass
261
269
 
262
- bandpass = BandPassReconstructed(
270
+ bp_cl = BandPassReconstructed if not args.match_projection else BandPass
271
+ bandpass = bp_cl(
263
272
  use_gaussian=args.no_pass_smooth,
264
273
  lowpass=lowpass,
265
274
  highpass=highpass,
266
275
  sampling_rate=template.sampling_rate,
267
276
  )
277
+ bandpass.opening_axis, bandpass.tilt_axis = args.wedge_axes
268
278
  template_filter.append(bandpass)
269
279
  target_filter.append(bandpass)
270
280
 
281
+ if not args.match_projection:
282
+ rec_filt = (Wedge, CTF)
283
+ needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
284
+ if needs_reconstruction > 0 and args.reconstruction_filter is None:
285
+ warnings.warn(
286
+ "Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
287
+ "to avoid artifacts from reconstruction using weighted backprojection."
288
+ )
289
+
290
+ template_filter = sorted(
291
+ template_filter, key=lambda x: type(x) in rec_filt, reverse=True
292
+ )
293
+ if needs_reconstruction > 0:
294
+ relevant_filters = [x for x in template_filter if type(x) in rec_filt]
295
+ if len(relevant_filters) == 0:
296
+ raise ValueError("Filters require ")
297
+
298
+ reconstruction_filter = ReconstructFromTilt(
299
+ reconstruction_filter=args.reconstruction_filter,
300
+ interpolation_order=args.reconstruction_interpolation_order,
301
+ angles=relevant_filters[0].angles,
302
+ opening_axis=args.wedge_axes[0],
303
+ tilt_axis=args.wedge_axes[1],
304
+ )
305
+ template_filter.insert(needs_reconstruction, reconstruction_filter)
306
+ else:
307
+ template_filter.append(ShiftFourier())
308
+ if len(target_filter):
309
+ target_filter.append(ShiftFourier())
310
+
311
+ # LinearWhiteningFilter does not support working on tilts yet, hence we
312
+ # can safely evaluate it after all other filters
271
313
  if args.whiten_spectrum:
272
314
  whitening_filter = LinearWhiteningFilter()
273
315
  template_filter.append(whitening_filter)
274
316
  target_filter.append(whitening_filter)
275
317
 
276
- rec_filt = (Wedge, CTF)
277
- needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
278
- if needs_reconstruction > 0 and args.reconstruction_filter is None:
279
- warnings.warn(
280
- "Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
281
- "to avoid artifacts from reconstruction using weighted backprojection."
282
- )
283
-
284
- template_filter = sorted(
285
- template_filter, key=lambda x: type(x) in rec_filt, reverse=True
286
- )
287
- if needs_reconstruction > 0:
288
- relevant_filters = [x for x in template_filter if type(x) in rec_filt]
289
- if len(relevant_filters) == 0:
290
- raise ValueError("Filters require ")
291
-
292
- reconstruction_filter = ReconstructFromTilt(
293
- reconstruction_filter=args.reconstruction_filter,
294
- interpolation_order=args.reconstruction_interpolation_order,
295
- angles=relevant_filters[0].angles,
296
- opening_axis=args.wedge_axes[0],
297
- tilt_axis=args.wedge_axes[1],
298
- )
299
- template_filter.insert(needs_reconstruction, reconstruction_filter)
300
-
301
318
  template_filter = Compose(template_filter) if len(template_filter) else None
302
319
  target_filter = Compose(target_filter) if len(target_filter) else None
303
320
  if args.no_filter_target:
@@ -359,8 +376,7 @@ def parse_args():
359
376
  "--invert-target-contrast",
360
377
  action="store_true",
361
378
  default=False,
362
- help="Invert the target contrast. Useful for matching on tomograms if the "
363
- "template has not been inverted.",
379
+ help="Invert contrast by multiplication with negative one.",
364
380
  )
365
381
  io_group.add_argument(
366
382
  "--scramble-phases",
@@ -411,6 +427,13 @@ def parse_args():
411
427
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
412
428
  help="Template matching scoring function.",
413
429
  )
430
+ scoring_group.add_argument(
431
+ "--background-correction",
432
+ choices=["phase-scrambling"],
433
+ required=False,
434
+ help="Transform cross-correlation into SNR-like values using a given method: "
435
+ "'phase-scrambling' uses a phase-scrambled template as background",
436
+ )
414
437
 
415
438
  angular_group = parser.add_argument_group("Angular Sampling")
416
439
  angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
@@ -493,9 +516,8 @@ def parse_args():
493
516
  computation_group.add_argument(
494
517
  "--gpu-indices",
495
518
  type=str,
496
- default=None,
497
- help="Comma-separated GPU indices (e.g., '0,1,2' for first 3 GPUs). Otherwise "
498
- "CUDA_VISIBLE_DEVICES will be used.",
519
+ default=os.environ.get("CUDA_VISIBLE_DEVICES"),
520
+ help="GPU indices, e.g., '0,1,2', defaults to CUDA_VISIBLE_DEVICES.",
499
521
  )
500
522
  computation_group.add_argument(
501
523
  "--memory",
@@ -527,15 +549,13 @@ def parse_args():
527
549
  "--lowpass",
528
550
  type=float,
529
551
  required=False,
530
- help="Resolution to lowpass filter template and target to in the same unit "
531
- "as the sampling rate of template and target (typically Ångstrom).",
552
+ help="Resolution to lowpass filter template and target to.",
532
553
  )
533
554
  filter_group.add_argument(
534
555
  "--highpass",
535
556
  type=float,
536
557
  required=False,
537
- help="Resolution to highpass filter template and target to in the same unit "
538
- "as the sampling rate of template and target (typically Ångstrom).",
558
+ help="Resolution to highpass filter template and target to.",
539
559
  )
540
560
  filter_group.add_argument(
541
561
  "--no-pass-smooth",
@@ -549,14 +569,13 @@ def parse_args():
549
569
  required=False,
550
570
  default="sampling_rate",
551
571
  choices=["sampling_rate", "voxel", "frequency"],
552
- help="How values passed to --lowpass and --highpass should be interpreted. "
553
- "Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
572
+ help="How values passed to --lowpass and --highpass should be interpreted. ",
554
573
  )
555
574
  filter_group.add_argument(
556
575
  "--whiten-spectrum",
557
576
  action="store_true",
558
577
  default=None,
559
- help="Apply spectral whitening to template and target based on target spectrum.",
578
+ help="Apply spectral whitening to template and target.",
560
579
  )
561
580
  filter_group.add_argument(
562
581
  "--wedge-axes",
@@ -603,6 +622,7 @@ def parse_args():
603
622
  type=int,
604
623
  default=1,
605
624
  required=False,
625
+ choices=[0, 1, 2, 3, 4, 5],
606
626
  help="Analogous to --interpolation-order but for reconstruction.",
607
627
  )
608
628
  filter_group.add_argument(
@@ -664,40 +684,32 @@ def parse_args():
664
684
  required=False,
665
685
  help="Do not perform phase-flipping CTF correction.",
666
686
  )
667
- ctf_group.add_argument(
668
- "--correct-defocus-gradient",
669
- action="store_true",
670
- required=False,
671
- help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
672
- "defocus gradients.",
673
- )
674
687
 
675
688
  performance_group = parser.add_argument_group("Performance")
676
689
  performance_group.add_argument(
677
690
  "--centering",
678
691
  action="store_true",
679
- help="Center the template in the box if it has not been done already.",
692
+ help="Translate the template's center of mass to the center of the box.",
680
693
  )
681
694
  performance_group.add_argument(
682
695
  "--pad-edges",
683
696
  action="store_true",
684
697
  default=False,
685
- help="Useful if the target does not have a well-defined bounding box. Will be "
686
- "activated automatically if splitting is required to avoid boundary artifacts.",
698
+ help="Zero pad the target. Defaults to True if splitting is required.",
687
699
  )
688
700
  performance_group.add_argument(
689
701
  "--interpolation-order",
690
702
  required=False,
691
703
  type=int,
692
704
  default=None,
693
- help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
694
- "and pytorch backends.",
705
+ choices=[0, 1, 2, 3, 4, 5],
706
+ help="Spline order for rotation, default is 3 and 1 for jax and pytorch.",
695
707
  )
696
708
  performance_group.add_argument(
697
709
  "--use-memmap",
698
710
  action="store_true",
699
711
  default=False,
700
- help="Memmap large data to disk, e.g., matching on unbinned tomograms.",
712
+ help="Memmap analyzer data, useful for matching on very large inputs.",
701
713
  )
702
714
 
703
715
  analyzer_group = parser.add_argument_group("Output / Analysis")
@@ -706,7 +718,7 @@ def parse_args():
706
718
  required=False,
707
719
  type=float,
708
720
  default=0,
709
- help="Minimum template matching scores to consider for analysis.",
721
+ help="Minimum template matching scores to consider.",
710
722
  )
711
723
  analyzer_group.add_argument(
712
724
  "-p",
@@ -724,30 +736,20 @@ def parse_args():
724
736
  args = parser.parse_args()
725
737
  args.version = __version__
726
738
 
727
- if args.interpolation_order is None:
728
- args.interpolation_order = 3
729
- if args.backend in ("jax", "pytorch"):
730
- args.interpolation_order = 1
731
- args.reconstruction_interpolation_order = 1
732
-
733
- if args.interpolation_order < 0:
734
- args.interpolation_order = None
735
-
736
- if args.temp_directory is None:
737
- args.temp_directory = gettempdir()
738
-
739
- os.environ["TMPDIR"] = args.temp_directory
740
- if args.gpu_indices is not None:
741
- os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
739
+ if args.temp_directory is not None:
740
+ os.environ["TMPDIR"] = args.temp_directory
742
741
 
743
- if args.tilt_angles is not None and not exists(args.tilt_angles):
742
+ # Tilt angles can be specified as range or using a suitable input file
743
+ is_file = exists(args.tilt_angles) if args.tilt_angles is not None else False
744
+ if args.tilt_angles is not None and not is_file:
744
745
  try:
745
- float(args.tilt_angles.split(",")[0])
746
+ args.tilt_angles = tuple(abs(float(x)) for x in args.tilt_angles.split(","))
746
747
  except Exception:
747
748
  raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
748
749
 
750
+ # Since both Wedge.from_file and CTF.from_file parse similar inputs, we can
751
+ # fall back to assigning the ctf_file to args.tilt_angles
749
752
  if args.ctf_file is not None and args.tilt_angles is None:
750
- # Check if tilt angles can be extracted from CTF specification
751
753
  try:
752
754
  ctf = CTF.from_file(args.ctf_file)
753
755
  if ctf.angles is None:
@@ -758,7 +760,14 @@ def parse_args():
758
760
  "Need to specify --tilt-angles when not provided in --ctf-file."
759
761
  )
760
762
 
761
- args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
763
+ # For projection matching we cannot use continuous wedge masks
764
+ args.match_projection = False
765
+ if not is_file and args.match_projection:
766
+ raise ValueError(
767
+ "Projection angles are required via --tilt-angles or --ctf-file."
768
+ )
769
+
770
+ # Handle constrained matching inputs
762
771
  if args.orientations is not None:
763
772
  orientations = Orientations.from_file(args.orientations)
764
773
  orientations.translations = np.divide(
@@ -766,6 +775,64 @@ def parse_args():
766
775
  )
767
776
  args.orientations = orientations
768
777
 
778
+ if args.orientations_uncertainty is not None:
779
+ args.orientations_uncertainty = tuple(
780
+ int(x) for x in args.orientations_uncertainty.split(",")
781
+ )
782
+
783
+ # Handle backend specificities
784
+ if args.interpolation_order is None:
785
+ args.interpolation_order = 3
786
+ if args.backend in ("jax", "pytorch"):
787
+ args.interpolation_order = 1
788
+ args.reconstruction_interpolation_order = 1
789
+
790
+ # This flag is not passed to backend yet, but might as well be verbose about it
791
+ if args.interpolation_order != 1 and args.backend == "jax":
792
+ warnings.warn("Setting interpolation order to order jax supports (1).")
793
+ args.interpolation_order = 1
794
+ args.reconstruction_interpolation_order = 1
795
+
796
+ if args.interpolation_order == 3 and args.backend == "pytorch":
797
+ warnings.warn("Pytorch does not support order 3, changing it to 1.")
798
+ args.interpolation_order = 1
799
+ if args.reconstruction_interpolation_order == 3:
800
+ args.reconstruction_interpolation_order = args.interpolation_order
801
+
802
+ # Handle GPU device specification for suitable backends
803
+ if args.backend in ("pytorch", "cupy", "jax"):
804
+ if args.gpu_indices is None:
805
+ warnings.warn(
806
+ "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
807
+ "Assuming device 0.",
808
+ )
809
+ args.gpu_indices = "0"
810
+
811
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
812
+ args.gpu_indices = [int(x) for x in args.gpu_indices.split(",")]
813
+ args.cores = len(args.gpu_indices)
814
+
815
+ if args.backend == "jax" and args.peak_calling:
816
+ raise ValueError("Jax supports only subclasses of MaxScoreOverRotations.")
817
+
818
+ # Wedge axes do not have meaning for projections
819
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
820
+ if args.match_projection:
821
+ args.wedge_axes = None, None
822
+
823
+ if args.match_projection and args.backend != "jax":
824
+ raise ValueError("Projection matching is only supported for --backend jax.")
825
+
826
+ # This is implicitly caught in the jax check above, but keeping it for future use
827
+ if args.match_projection and args.peak_calling:
828
+ raise ValueError("Peak calling is not yet supported for projection matching.")
829
+
830
+ if args.orientations is not None and args.peak_calling:
831
+ raise ValueError(
832
+ "Peak calling and constrained matching simultaneously is not yet supported."
833
+ )
834
+
835
+ # Avoid relative input specification
769
836
  args.target = abspath(args.target)
770
837
  if args.target_mask is not None:
771
838
  args.target_mask = abspath(args.target_mask)
@@ -782,7 +849,6 @@ def main():
782
849
  print_entry()
783
850
 
784
851
  target = Density.from_file(args.target, use_memmap=True)
785
-
786
852
  try:
787
853
  template = Density.from_file(args.template)
788
854
  except Exception:
@@ -800,27 +866,24 @@ def main():
800
866
  )
801
867
 
802
868
  if target.sampling_rate.size == template.sampling_rate.size:
803
- if not np.allclose(
869
+ sampling_rate_match = np.allclose(
804
870
  np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
805
- ):
871
+ )
872
+ # For projection we omit the warning as the leading dimension has no sampling
873
+ if not sampling_rate_match and not args.match_projection:
806
874
  warnings.warn(
807
875
  f"Sampling rate mismatch detected: target={target.sampling_rate} "
808
876
  f"template={template.sampling_rate}. Proceeding with user-provided "
809
877
  f"values. Make sure this is intentional. "
810
878
  )
811
879
 
812
- template_mask = load_and_validate_mask(
813
- mask_target=template, mask_path=args.template_mask
814
- )
815
- target_mask = load_and_validate_mask(
816
- mask_target=target, mask_path=args.target_mask, use_memmap=True
817
- )
880
+ template_mask = load_and_validate_mask(template, args.template_mask)
881
+ target_mask = load_and_validate_mask(target, args.target_mask, use_memmap=True)
818
882
 
819
- initial_shape = target.shape
820
883
  print_block(
821
884
  name="Target",
822
885
  data={
823
- "Initial Shape": initial_shape,
886
+ "Initial Shape": target.shape,
824
887
  "Sampling Rate": _format_sampling(target.sampling_rate),
825
888
  "Final Shape": target.shape,
826
889
  },
@@ -830,16 +893,15 @@ def main():
830
893
  print_block(
831
894
  name="Target Mask",
832
895
  data={
833
- "Initial Shape": initial_shape,
896
+ "Initial Shape": target_mask.shape,
834
897
  "Sampling Rate": _format_sampling(target_mask.sampling_rate),
835
898
  "Final Shape": target_mask.shape,
836
899
  },
837
900
  )
838
901
 
839
902
  initial_shape = template.shape
840
- translation = np.zeros(len(template.shape), dtype=np.float32)
841
903
  if args.centering:
842
- template, translation = template.centered(0)
904
+ template = template.centered(0)
843
905
 
844
906
  print_block(
845
907
  name="Template",
@@ -852,27 +914,12 @@ def main():
852
914
 
853
915
  if template_mask is None:
854
916
  template_mask = template.empty
855
- if not args.centering:
856
- enclosing_box = template.minimum_enclosing_box(
857
- 0, use_geometric_center=False
858
- )
859
- template_mask.adjust_box(enclosing_box)
860
-
861
- template_mask.data[:] = 1
862
- translation = np.zeros_like(translation)
863
917
 
864
- template_mask.pad(template.shape, center=False)
865
- origin_translation = np.divide(
866
- np.subtract(template.origin, template_mask.origin), template.sampling_rate
867
- )
868
- translation = np.add(translation, origin_translation)
918
+ # Pre 0.3.2 we used to perform a rigid transform on the template mask to match
919
+ # the template origin, but this seems overly pedantic given the sporadic use
920
+ # of the origin parameter in the matching pipeline
921
+ template_mask.data = np.ones(template.shape, dtype=template.data.dtype)
869
922
 
870
- template_mask = template_mask.rigid_transform(
871
- rotation_matrix=np.eye(template_mask.data.ndim),
872
- translation=-translation,
873
- order=1,
874
- )
875
- template_mask.origin = template.origin.copy()
876
923
  print_block(
877
924
  name="Template Mask",
878
925
  data={
@@ -883,71 +930,35 @@ def main():
883
930
  )
884
931
  print("\n" + "-" * 80)
885
932
 
886
- if args.scramble_phases:
887
- template.data = scramble_phases(template.data, noise_proportion=1.0)
888
-
889
933
  callback_class = MaxScoreOverRotations
890
934
  if args.orientations is not None:
891
935
  callback_class = MaxScoreOverRotationsConstrained
892
936
  elif args.peak_calling:
893
937
  callback_class = PeakCallerMaximumFilter
894
938
 
895
- # Determine suitable backend for the selected operation
896
- available_backends = be.available_backends()
897
- if args.backend not in available_backends:
898
- raise ValueError("Requested backend is not available.")
899
- if args.backend == "jax" and callback_class != MaxScoreOverRotations:
900
- raise ValueError(
901
- "Jax backend only supports the MaxScoreOverRotations analyzer."
902
- )
903
-
904
- if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
905
- warnings.warn(
906
- "Jax and pytorch do not support interpolation order 3, setting it to 1."
907
- )
908
- args.interpolation_order = 1
909
-
910
- if args.backend in ("pytorch", "cupy", "jax"):
911
- gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
912
- if gpu_devices is None:
913
- warnings.warn(
914
- "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
915
- "Assuming device 0.",
916
- )
917
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
918
-
919
- args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
920
- args.gpu_indices = [
921
- int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
922
- ]
923
-
924
- # Finally set the desired backend
925
- device = "cuda"
926
- args.use_gpu = False
927
- be.change_backend(args.backend)
939
+ # We currently do not allow parallelizing angular searches in the GPU compatible
940
+ # backends, so we keep this flag to compute a suitable splitting schedule
941
+ use_gpu = False
928
942
  if args.backend in ("jax", "pytorch", "cupy"):
929
- args.use_gpu = True
943
+ use_gpu = True
930
944
 
945
+ # Finally set the requested backend
946
+ be.change_backend(args.backend)
931
947
  if args.backend == "pytorch":
932
948
  try:
933
- be.change_backend("pytorch", device=device)
949
+ be.change_backend("pytorch", device="cuda")
934
950
  # Trigger exception if not compiled with device
935
951
  be.get_available_memory()
936
952
  except Exception as e:
953
+ # Let the user know they did not compile with GPU devices
937
954
  print(e)
938
- device = "cpu"
939
- args.use_gpu = False
940
- be.change_backend("pytorch", device=device)
955
+ use_gpu = False
956
+ be.change_backend("pytorch", device="cpu")
941
957
 
942
958
  available_memory = be.get_available_memory() * be.device_count()
943
959
  if args.memory is None:
944
960
  args.memory = int(args.memory_scaling * available_memory)
945
961
 
946
- if args.orientations_uncertainty is not None:
947
- args.orientations_uncertainty = tuple(
948
- int(x) for x in args.orientations_uncertainty.split(",")
949
- )
950
-
951
962
  matching_data = MatchingData(
952
963
  target=target,
953
964
  template=template.data,
@@ -956,19 +967,23 @@ def main():
956
967
  invert_target=args.invert_target_contrast,
957
968
  rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
958
969
  )
959
-
960
- matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
961
- matching_data.template_filter, matching_data.target_filter = setup_filter(
962
- args, template, target
963
- )
970
+ if args.scramble_phases:
971
+ matching_data.template = matching_data.transform_template("phase_randomization")
964
972
 
965
973
  matching_data.set_matching_dimension(
966
974
  target_dim=target.metadata.get("batch_dimension", None),
967
975
  template_dim=template.metadata.get("batch_dimension", None),
968
976
  )
977
+ if args.match_projection:
978
+ matching_data.set_matching_dimension(target_dim=0)
979
+
969
980
  args.batch_dims = tuple(int(x) for x in np.where(matching_data._batch_mask)[0])
981
+ matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
982
+ matching_data.template_filter, matching_data.target_filter = setup_filter(
983
+ args, template, target
984
+ )
970
985
 
971
- splits, schedule = compute_schedule(args, matching_data, callback_class)
986
+ splits, schedule = compute_schedule(args, matching_data, callback_class, use_gpu)
972
987
 
973
988
  n_splits = np.prod(list(splits.values()))
974
989
  target_split = ", ".join(
@@ -980,6 +995,7 @@ def main():
980
995
  f" [{matching_data.rotations.shape[0]} rotations]",
981
996
  "Center Template": args.centering,
982
997
  "Scramble Template": args.scramble_phases,
998
+ "Background Correction": args.background_correction,
983
999
  "Invert Contrast": args.invert_target_contrast,
984
1000
  "Extend Target Edges": args.pad_edges,
985
1001
  "Interpolation Order": args.interpolation_order,
@@ -1020,7 +1036,6 @@ def main():
1020
1036
  if args.ctf_file is not None or args.defocus is not None:
1021
1037
  filter_args["CTF File"] = args.ctf_file
1022
1038
  filter_args["Flip Phase"] = args.no_flip_phase
1023
- filter_args["Correct Defocus"] = args.correct_defocus_gradient
1024
1039
 
1025
1040
  filter_args = {k: v for k, v in filter_args.items() if v is not None}
1026
1041
  if len(filter_args):
@@ -1042,7 +1057,7 @@ def main():
1042
1057
  analyzer_args["acceptance_radius"] = args.orientations_uncertainty
1043
1058
  analyzer_args["positions"] = args.orientations.translations
1044
1059
  analyzer_args["rotations"] = euler_to_rotationmatrix(
1045
- args.orientations.rotations
1060
+ args.orientations.rotations, seq="ZYZ"
1046
1061
  )
1047
1062
 
1048
1063
  print_block(
@@ -1062,7 +1077,7 @@ def main():
1062
1077
 
1063
1078
  start = time()
1064
1079
  print("Running Template Matching. This might take a while ...")
1065
- candidates = scan_subsets(
1080
+ candidates = match_exhaustive(
1066
1081
  matching_data=matching_data,
1067
1082
  job_schedule=schedule,
1068
1083
  matching_score=matching_score,
@@ -1072,6 +1087,8 @@ def main():
1072
1087
  target_splits=splits,
1073
1088
  pad_target_edges=args.pad_edges,
1074
1089
  interpolation_order=args.interpolation_order,
1090
+ match_projection=args.match_projection,
1091
+ background_correction=args.background_correction,
1075
1092
  )
1076
1093
 
1077
1094
  candidates = list(candidates) if candidates is not None else []