pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
  2. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
  3. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
  4. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
  5. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
  6. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
  7. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
  8. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
  9. scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
  10. scripts/extract_candidates.py +224 -0
  11. scripts/match_template.py +224 -223
  12. scripts/postprocess.py +283 -163
  13. scripts/preprocess.py +11 -8
  14. scripts/preprocessor_gui.py +10 -9
  15. scripts/refine_matches.py +626 -0
  16. tests/preprocessing/test_frequency_filters.py +9 -4
  17. tests/test_analyzer.py +143 -138
  18. tests/test_matching_cli.py +85 -29
  19. tests/test_matching_exhaustive.py +1 -2
  20. tests/test_matching_optimization.py +4 -9
  21. tests/test_orientations.py +0 -1
  22. tme/__version__.py +1 -1
  23. tme/analyzer/__init__.py +2 -0
  24. tme/analyzer/_utils.py +25 -17
  25. tme/analyzer/aggregation.py +385 -220
  26. tme/analyzer/base.py +138 -0
  27. tme/analyzer/peaks.py +150 -88
  28. tme/analyzer/proxy.py +122 -0
  29. tme/backends/__init__.py +4 -3
  30. tme/backends/_cupy_utils.py +25 -24
  31. tme/backends/_jax_utils.py +4 -3
  32. tme/backends/cupy_backend.py +4 -13
  33. tme/backends/jax_backend.py +6 -8
  34. tme/backends/matching_backend.py +4 -3
  35. tme/backends/mlx_backend.py +4 -3
  36. tme/backends/npfftw_backend.py +7 -5
  37. tme/backends/pytorch_backend.py +14 -4
  38. tme/cli.py +126 -0
  39. tme/density.py +4 -3
  40. tme/filters/__init__.py +1 -1
  41. tme/filters/_utils.py +4 -3
  42. tme/filters/bandpass.py +6 -4
  43. tme/filters/compose.py +5 -4
  44. tme/filters/ctf.py +426 -214
  45. tme/filters/reconstruction.py +58 -28
  46. tme/filters/wedge.py +139 -61
  47. tme/filters/whitening.py +36 -36
  48. tme/matching_data.py +4 -3
  49. tme/matching_exhaustive.py +17 -16
  50. tme/matching_optimization.py +5 -4
  51. tme/matching_scores.py +4 -3
  52. tme/matching_utils.py +6 -4
  53. tme/memory.py +4 -3
  54. tme/orientations.py +9 -6
  55. tme/parser.py +5 -4
  56. tme/preprocessor.py +4 -3
  57. tme/rotations.py +10 -7
  58. tme/structure.py +4 -3
  59. tests/data/Maps/.DS_Store +0 -0
  60. tests/data/Structures/.DS_Store +0 -0
  61. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
  62. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
  63. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
scripts/match_template.py CHANGED
@@ -1,9 +1,9 @@
1
1
  #!python3
2
- """ CLI for basic pyTME template matching functions.
2
+ """CLI for basic pyTME template matching functions.
3
3
 
4
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
5
 
6
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
  import os
9
9
  import argparse
@@ -18,61 +18,31 @@ from tempfile import gettempdir
18
18
  import numpy as np
19
19
 
20
20
  from tme.backends import backend as be
21
- from tme import Density, __version__
21
+ from tme import Density, __version__, Orientations
22
22
  from tme.matching_utils import scramble_phases, write_pickle
23
23
  from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
24
24
  from tme.rotations import (
25
25
  get_cone_rotations,
26
26
  get_rotation_matrices,
27
+ euler_to_rotationmatrix,
27
28
  )
28
29
  from tme.matching_data import MatchingData
29
30
  from tme.analyzer import (
30
31
  MaxScoreOverRotations,
31
32
  PeakCallerMaximumFilter,
33
+ MaxScoreOverRotationsConstrained,
32
34
  )
33
35
  from tme.filters import (
34
36
  CTF,
35
37
  Wedge,
36
38
  Compose,
37
39
  BandPassFilter,
40
+ CTFReconstructed,
38
41
  WedgeReconstructed,
39
42
  ReconstructFromTilt,
40
43
  LinearWhiteningFilter,
41
44
  )
42
-
43
-
44
- def get_func_fullname(func) -> str:
45
- """Returns the full name of the given function, including its module."""
46
- return f"<function '{func.__module__}.{func.__name__}'>"
47
-
48
-
49
- def print_block(name: str, data: dict, label_width=20) -> None:
50
- """Prints a formatted block of information."""
51
- print(f"\n> {name}")
52
- for key, value in data.items():
53
- if isinstance(value, np.ndarray):
54
- value = value.shape
55
- formatted_value = str(value)
56
- print(f" - {key + ':':<{label_width}} {formatted_value}")
57
-
58
-
59
- def print_entry() -> None:
60
- width = 80
61
- text = f" pytme v{__version__} "
62
- padding_total = width - len(text) - 2
63
- padding_left = padding_total // 2
64
- padding_right = padding_total - padding_left
65
-
66
- print("*" * width)
67
- print(f"*{ ' ' * padding_left }{text}{ ' ' * padding_right }*")
68
- print("*" * width)
69
-
70
-
71
- def check_positive(value):
72
- ivalue = float(value)
73
- if ivalue <= 0:
74
- raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
75
- return ivalue
45
+ from tme.cli import get_func_fullname, print_block, print_entry, check_positive
76
46
 
77
47
 
78
48
  def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
@@ -118,6 +88,14 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
118
88
 
119
89
 
120
90
  def parse_rotation_logic(args, ndim):
91
+ if args.particle_diameter is not None:
92
+ resolution = Density.from_file(args.target, use_memmap=True)
93
+ resolution = 360 * np.maximum(
94
+ np.max(2 * resolution.sampling_rate),
95
+ args.lowpass if args.lowpass is not None else 0,
96
+ )
97
+ args.angular_sampling = resolution / (3.14159265358979 * args.particle_diameter)
98
+
121
99
  if args.angular_sampling is not None:
122
100
  rotations = get_rotation_matrices(
123
101
  angular_sampling=args.angular_sampling,
@@ -178,123 +156,72 @@ def compute_schedule(
178
156
 
179
157
 
180
158
  def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
181
- needs_reconstruction = False
182
159
  template_filter, target_filter = [], []
160
+
161
+ wedge = None
183
162
  if args.tilt_angles is not None:
184
- needs_reconstruction = args.tilt_weighting is not None
185
163
  try:
186
164
  wedge = Wedge.from_file(args.tilt_angles)
187
165
  wedge.weight_type = args.tilt_weighting
188
- if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
166
+ if args.tilt_weighting in ("angle", None):
189
167
  wedge = WedgeReconstructed(
190
168
  angles=wedge.angles,
191
169
  weight_wedge=args.tilt_weighting == "angle",
192
- opening_axis=args.wedge_axes[0],
193
- tilt_axis=args.wedge_axes[1],
194
170
  )
195
- except FileNotFoundError:
196
- tilt_step, create_continuous_wedge = None, True
171
+ except (FileNotFoundError, AttributeError):
197
172
  tilt_start, tilt_stop = args.tilt_angles.split(",")
198
- if ":" in tilt_stop:
199
- create_continuous_wedge = False
200
- tilt_stop, tilt_step = tilt_stop.split(":")
201
- tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
202
- tilt_angles = (tilt_start, tilt_stop)
203
- if tilt_step is not None:
204
- tilt_step = float(tilt_step)
205
- tilt_angles = np.arange(
206
- -tilt_start, tilt_stop + tilt_step, tilt_step
207
- ).tolist()
208
-
209
- if args.tilt_weighting is not None and tilt_step is None:
210
- raise ValueError(
211
- "Tilt weighting is not supported for continuous wedges."
212
- )
213
- if args.tilt_weighting not in ("angle", None):
214
- raise ValueError(
215
- "Tilt weighting schemes other than 'angle' or 'None' require "
216
- "a specification of electron doses via --tilt_angles."
217
- )
218
-
219
- wedge = Wedge(
220
- angles=tilt_angles,
221
- opening_axis=args.wedge_axes[0],
222
- tilt_axis=args.wedge_axes[1],
223
- shape=None,
224
- weight_type=None,
225
- weights=np.ones_like(tilt_angles),
173
+ tilt_start, tilt_stop = abs(float(tilt_start)), abs(float(tilt_stop))
174
+ wedge = WedgeReconstructed(
175
+ angles=(tilt_start, tilt_stop),
176
+ create_continuous_wedge=True,
177
+ weight_wedge=False,
178
+ reconstruction_filter=args.reconstruction_filter,
226
179
  )
227
- if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
228
- wedge = WedgeReconstructed(
229
- angles=tilt_angles,
230
- weight_wedge=args.tilt_weighting == "angle",
231
- create_continuous_wedge=create_continuous_wedge,
232
- reconstruction_filter=args.reconstruction_filter,
233
- opening_axis=args.wedge_axes[0],
234
- tilt_axis=args.wedge_axes[1],
235
- )
236
- wedge_target = WedgeReconstructed(
237
- angles=(np.abs(np.min(tilt_angles)), np.abs(np.max(tilt_angles))),
238
- weight_wedge=False,
239
- create_continuous_wedge=True,
240
- opening_axis=args.wedge_axes[0],
241
- tilt_axis=args.wedge_axes[1],
242
- )
243
- target_filter.append(wedge_target)
244
180
 
245
- wedge.sampling_rate = template.sampling_rate
181
+ wedge_target = WedgeReconstructed(
182
+ angles=wedge.angles,
183
+ weight_wedge=False,
184
+ create_continuous_wedge=True,
185
+ opening_axis=args.wedge_axes[0],
186
+ tilt_axis=args.wedge_axes[1],
187
+ )
188
+ wedge.opening_axis = args.wedge_axes[0]
189
+ wedge.tilt_axis = args.wedge_axes[1]
190
+
191
+ target_filter.append(wedge_target)
246
192
  template_filter.append(wedge)
247
- if not isinstance(wedge, WedgeReconstructed):
248
- reconstruction_filter = ReconstructFromTilt(
249
- reconstruction_filter=args.reconstruction_filter,
250
- interpolation_order=args.reconstruction_interpolation_order,
251
- )
252
- template_filter.append(reconstruction_filter)
253
193
 
194
+ args.ctf_file is not None
254
195
  if args.ctf_file is not None or args.defocus is not None:
255
- needs_reconstruction = True
256
- if args.ctf_file is not None:
196
+ try:
257
197
  ctf = CTF.from_file(args.ctf_file)
198
+ if (len(ctf.angles) == 0) and wedge is None:
199
+ raise ValueError(
200
+ "You requested to specify the CTF per tilt, but did not specify "
201
+ "tilt angles via --tilt_angles or --ctf_file (Warp/M XML format). "
202
+ )
203
+ if len(ctf.angles) == 0:
204
+ ctf.angles = wedge.angles
205
+
258
206
  n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
259
- if n_tilts_ctfs != n_tils_angles:
207
+ if (n_tilts_ctfs != n_tils_angles) and isinstance(wedge, Wedge):
260
208
  raise ValueError(
261
- f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
209
+ f"CTF file contains {n_tilts_ctfs} tilt, but match_template "
262
210
  f"recieved {n_tils_angles} tilt angles. Expected one angle "
263
- "per micrograph."
211
+ "per tilt."
264
212
  )
265
- ctf.angles = wedge.angles
266
- ctf.no_reconstruction = False
267
- ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
268
- else:
269
- needs_reconstruction = False
270
- ctf = CTF(
271
- defocus_x=[args.defocus],
272
- phase_shift=[args.phase_shift],
273
- defocus_y=None,
274
- angles=[0],
275
- shape=None,
276
- return_real_fourier=True,
277
- )
213
+
214
+ except (FileNotFoundError, AttributeError):
215
+ ctf = CTFReconstructed(defocus_x=args.defocus, phase_shift=args.phase_shift)
216
+
217
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
278
218
  ctf.sampling_rate = template.sampling_rate
279
219
  ctf.flip_phase = args.no_flip_phase
280
220
  ctf.amplitude_contrast = args.amplitude_contrast
281
221
  ctf.spherical_aberration = args.spherical_aberration
282
222
  ctf.acceleration_voltage = args.acceleration_voltage * 1e3
283
223
  ctf.correct_defocus_gradient = args.correct_defocus_gradient
284
-
285
- if not needs_reconstruction:
286
- template_filter.append(ctf)
287
- elif isinstance(template_filter[-1], ReconstructFromTilt):
288
- template_filter.insert(-1, ctf)
289
- else:
290
- template_filter.insert(0, ctf)
291
- template_filter.insert(
292
- 1,
293
- ReconstructFromTilt(
294
- reconstruction_filter=args.reconstruction_filter,
295
- interpolation_order=args.reconstruction_interpolation_order,
296
- ),
297
- )
224
+ template_filter.append(ctf)
298
225
 
299
226
  if args.lowpass or args.highpass is not None:
300
227
  lowpass, highpass = args.lowpass, args.highpass
@@ -329,11 +256,31 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
329
256
  template_filter.append(whitening_filter)
330
257
  target_filter.append(whitening_filter)
331
258
 
332
- if needs_reconstruction and args.reconstruction_filter is None:
259
+ rec_filt = (Wedge, CTF)
260
+ needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
261
+ if needs_reconstruction > 0 and args.reconstruction_filter is None:
333
262
  warnings.warn(
334
- "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
263
+ "Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
264
+ "to avoid artifacts from reconstruction using weighted backprojection."
335
265
  )
336
266
 
267
+ template_filter = sorted(
268
+ template_filter, key=lambda x: type(x) in rec_filt, reverse=True
269
+ )
270
+ if needs_reconstruction > 0:
271
+ relevant_filters = [x for x in template_filter if type(x) in rec_filt]
272
+ if len(relevant_filters) == 0:
273
+ raise ValueError("Filters require ")
274
+
275
+ reconstruction_filter = ReconstructFromTilt(
276
+ reconstruction_filter=args.reconstruction_filter,
277
+ interpolation_order=args.reconstruction_interpolation_order,
278
+ angles=relevant_filters[0].angles,
279
+ opening_axis=args.wedge_axes[0],
280
+ tilt_axis=args.wedge_axes[1],
281
+ )
282
+ template_filter.insert(needs_reconstruction, reconstruction_filter)
283
+
337
284
  template_filter = Compose(template_filter) if len(template_filter) else None
338
285
  target_filter = Compose(target_filter) if len(target_filter) else None
339
286
  if args.no_filter_target:
@@ -342,6 +289,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
342
289
  return template_filter, target_filter
343
290
 
344
291
 
292
+ def _format_sampling(arr, decimals: int = 2):
293
+ return tuple(round(float(x), decimals) for x in arr)
294
+
295
+
345
296
  def parse_args():
346
297
  parser = argparse.ArgumentParser(
347
298
  description="Perform template matching.",
@@ -406,6 +357,40 @@ def parse_args():
406
357
  help="Phase scramble the template to generate a noise score background.",
407
358
  )
408
359
 
360
+ sampling_group = parser.add_argument_group("Sampling")
361
+ sampling_group.add_argument(
362
+ "--orientations",
363
+ dest="orientations",
364
+ default=None,
365
+ required=False,
366
+ help="Path to a file readable via Orientations.from_file containing "
367
+ "translations and rotations of candidate peaks to refine.",
368
+ )
369
+ sampling_group.add_argument(
370
+ "--orientations_scaling",
371
+ required=False,
372
+ type=float,
373
+ default=1.0,
374
+ help="Scaling factor to map candidate translations onto the target. "
375
+ "Assuming coordinates are in Å and target sampling rate are 3Å/voxel, "
376
+ "the corresponding orientations_scaling would be 3.",
377
+ )
378
+ sampling_group.add_argument(
379
+ "--orientations_cone",
380
+ required=False,
381
+ type=float,
382
+ default=20.0,
383
+ help="Accept orientations within specified cone angle of each orientation.",
384
+ )
385
+ sampling_group.add_argument(
386
+ "--orientations_uncertainty",
387
+ required=False,
388
+ type=str,
389
+ default="10",
390
+ help="Accept translations within the specified radius of each orientation. "
391
+ "Can be a single value or comma-separated string for per-axis uncertainty.",
392
+ )
393
+
409
394
  scoring_group = parser.add_argument_group("Scoring")
410
395
  scoring_group.add_argument(
411
396
  "-s",
@@ -435,6 +420,13 @@ def parse_args():
435
420
  help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
436
421
  "narrow interval around a known orientation, e.g. for surface oversampling.",
437
422
  )
423
+ angular_exclusive.add_argument(
424
+ "--particle_diameter",
425
+ dest="particle_diameter",
426
+ type=check_positive,
427
+ default=None,
428
+ help="Particle diameter in units of sampling rate.",
429
+ )
438
430
  angular_group.add_argument(
439
431
  "--cone_axis",
440
432
  dest="cone_axis",
@@ -517,6 +509,7 @@ def parse_args():
517
509
  computation_group.add_argument(
518
510
  "-r",
519
511
  "--ram",
512
+ "--memory",
520
513
  dest="memory",
521
514
  required=False,
522
515
  type=int,
@@ -529,7 +522,7 @@ def parse_args():
529
522
  required=False,
530
523
  type=float,
531
524
  default=0.85,
532
- help="Fraction of available memory to be used. Ignored if --ram is set.",
525
+ help="Fraction of available memory to be used. Ignored if --memory is set.",
533
526
  )
534
527
  computation_group.add_argument(
535
528
  "--temp_directory",
@@ -540,7 +533,7 @@ def parse_args():
540
533
  computation_group.add_argument(
541
534
  "--backend",
542
535
  dest="backend",
543
- default=None,
536
+ default=be._backend_name,
544
537
  choices=be.available_backends(),
545
538
  help="[Expert] Overwrite default computation backend.",
546
539
  )
@@ -575,7 +568,8 @@ def parse_args():
575
568
  required=False,
576
569
  default="sampling_rate",
577
570
  choices=["sampling_rate", "voxel", "frequency"],
578
- help="How values passed to --lowpass and --highpass should be interpreted. ",
571
+ help="How values passed to --lowpass and --highpass should be interpreted. "
572
+ "Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
579
573
  )
580
574
  filter_group.add_argument(
581
575
  "--whiten_spectrum",
@@ -589,9 +583,9 @@ def parse_args():
589
583
  dest="wedge_axes",
590
584
  type=str,
591
585
  required=False,
592
- default=None,
593
- help="Indices of wedge opening and tilt axis, e.g. '2,0' for a wedge open "
594
- "in z and tilted over the x-axis.",
586
+ default="2,0",
587
+ help="Indices of projection (wedge opening) and tilt axis, e.g., '2,0' "
588
+ "for the typical projection over z and tilting over the x-axis.",
595
589
  )
596
590
  filter_group.add_argument(
597
591
  "--tilt_angles",
@@ -599,10 +593,12 @@ def parse_args():
599
593
  type=str,
600
594
  required=False,
601
595
  default=None,
602
- help="Path to a tab-separated file containing the column angles and optionally "
603
- " weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
604
- " yields a continuous wedge mask. Alternatively, a tilt step size can be "
605
- "specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
596
+ help="Path to a file specifying tilt angles. This can be a Warp/M XML file, "
597
+ "a tomostar STAR file, a tab-separated file with column name 'angles', or a "
598
+ "single column file without header. Exposure will be taken from the input file "
599
+ ", if you are using a tab-separated file, the column names 'angles' and "
600
+ "'weights' need to be present. It is also possible to specify a continuous "
601
+ "wedge mask using e.g., -50,45.",
606
602
  )
607
603
  filter_group.add_argument(
608
604
  "--tilt_weighting",
@@ -649,8 +645,9 @@ def parse_args():
649
645
  type=str,
650
646
  required=False,
651
647
  default=None,
652
- help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
653
- "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
648
+ help="Path to a file with CTF parameters. This can be a Warp/M XML file "
649
+ "a GCTF/Relion STAR file, or the output of CTFFIND4. If the file does not "
650
+ "specify tilt angles, the angles specified with --tilt_angles are used.",
654
651
  )
655
652
  ctf_group.add_argument(
656
653
  "--defocus",
@@ -658,8 +655,8 @@ def parse_args():
658
655
  type=float,
659
656
  required=False,
660
657
  default=None,
661
- help="Defocus in units of sampling rate (typically Ångstrom). "
662
- "Superseded by --ctf_file.",
658
+ help="Defocus in units of sampling rate (typically Ångstrom), e.g., 30000 "
659
+ "for a defocus of 3 micrometer. Superseded by --ctf_file.",
663
660
  )
664
661
  ctf_group.add_argument(
665
662
  "--phase_shift",
@@ -745,7 +742,8 @@ def parse_args():
745
742
  dest="use_mixed_precision",
746
743
  action="store_true",
747
744
  default=False,
748
- help="Use float16 for real values operations where possible.",
745
+ help="Use float16 for real values operations where possible. Not supported "
746
+ "for jax backend.",
749
747
  )
750
748
  performance_group.add_argument(
751
749
  "--use_memmap",
@@ -773,8 +771,8 @@ def parse_args():
773
771
  help="Perform peak calling instead of score aggregation.",
774
772
  )
775
773
  analyzer_group.add_argument(
776
- "--number_of_peaks",
777
- dest="number_of_peaks",
774
+ "--num_peaks",
775
+ dest="num_peaks",
778
776
  action="store_true",
779
777
  default=1000,
780
778
  help="Number of peaks to call, 1000 by default.",
@@ -794,21 +792,14 @@ def parse_args():
794
792
  f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
795
793
  )
796
794
 
797
- gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
798
795
  if args.gpu_indices is not None:
799
796
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
800
797
 
801
798
  if args.use_gpu:
802
- gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
803
- if gpu_devices is None:
804
- print(
805
- "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set.",
806
- "Assuming device 0.",
807
- )
808
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
809
- args.gpu_indices = [
810
- int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
811
- ]
799
+ warnings.warn(
800
+ "The use_gpu flag is no longer required and automatically "
801
+ "determined based on the selected backend."
802
+ )
812
803
 
813
804
  if args.tilt_angles is not None:
814
805
  if args.wedge_axes is None:
@@ -825,6 +816,13 @@ def parse_args():
825
816
  if args.wedge_axes is not None:
826
817
  args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
827
818
 
819
+ if args.orientations is not None:
820
+ orientations = Orientations.from_file(args.orientations)
821
+ orientations.translations = np.divide(
822
+ orientations.translations, args.orientations_scaling
823
+ )
824
+ args.orientations = orientations
825
+
828
826
  return args
829
827
 
830
828
 
@@ -864,7 +862,7 @@ def main():
864
862
  name="Target",
865
863
  data={
866
864
  "Initial Shape": initial_shape,
867
- "Sampling Rate": tuple(np.round(target.sampling_rate, 2)),
865
+ "Sampling Rate": _format_sampling(target.sampling_rate),
868
866
  "Final Shape": target.shape,
869
867
  },
870
868
  )
@@ -874,7 +872,7 @@ def main():
874
872
  name="Target Mask",
875
873
  data={
876
874
  "Initial Shape": initial_shape,
877
- "Sampling Rate": tuple(np.round(target_mask.sampling_rate, 2)),
875
+ "Sampling Rate": _format_sampling(target_mask.sampling_rate),
878
876
  "Final Shape": target_mask.shape,
879
877
  },
880
878
  )
@@ -887,7 +885,7 @@ def main():
887
885
  name="Template",
888
886
  data={
889
887
  "Initial Shape": initial_shape,
890
- "Sampling Rate": tuple(np.round(template.sampling_rate, 2)),
888
+ "Sampling Rate": _format_sampling(template.sampling_rate),
891
889
  "Final Shape": template.shape,
892
890
  },
893
891
  )
@@ -919,7 +917,7 @@ def main():
919
917
  name="Template Mask",
920
918
  data={
921
919
  "Inital Shape": initial_shape,
922
- "Sampling Rate": tuple(np.round(template_mask.sampling_rate, 2)),
920
+ "Sampling Rate": _format_sampling(template_mask.sampling_rate),
923
921
  "Final Shape": template_mask.shape,
924
922
  },
925
923
  )
@@ -930,65 +928,71 @@ def main():
930
928
  template.data, noise_proportion=1.0, normalize_power=False
931
929
  )
932
930
 
933
- # Determine suitable backend for the selected operation
934
- available_backends = be.available_backends()
935
- if args.backend is not None:
936
- req_backend = args.backend
937
- if req_backend not in available_backends:
938
- raise ValueError("Requested backend is not available.")
939
- available_backends = [req_backend]
940
-
941
- be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
942
- if args.use_gpu:
943
- args.cores = len(args.gpu_indices)
944
- be_selection = ("pytorch", "cupy", "jax")
945
- if args.use_mixed_precision:
946
- be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
947
-
948
- available_backends = [x for x in available_backends if x in be_selection]
931
+ callback_class = MaxScoreOverRotations
949
932
  if args.peak_calling:
950
- if "jax" in available_backends:
951
- available_backends.remove("jax")
952
- if args.use_gpu and "pytorch" in available_backends:
953
- available_backends = ("pytorch",)
933
+ callback_class = PeakCallerMaximumFilter
954
934
 
955
- # dim_match = len(template.shape) == len(target.shape) <= 3
956
- # if dim_match and args.use_gpu and "jax" in available_backends:
957
- # args.interpolation_order = 1
958
- # available_backends = ["jax"]
935
+ if args.orientations is not None:
936
+ callback_class = MaxScoreOverRotationsConstrained
959
937
 
960
- backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
961
- if args.use_gpu:
962
- backend_preference = ("cupy", "pytorch", "jax")
963
- for pref in backend_preference:
964
- if pref not in available_backends:
965
- continue
966
- be.change_backend(pref)
967
- if pref == "pytorch":
968
- be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
969
-
970
- if args.use_mixed_precision:
971
- be.change_backend(
972
- backend_name=pref,
973
- default_dtype=be._array_backend.float16,
974
- complex_dtype=be._array_backend.complex64,
975
- default_dtype_int=be._array_backend.int16,
976
- )
977
- break
938
+ # Determine suitable backend for the selected operation
939
+ available_backends = be.available_backends()
940
+ if args.backend not in available_backends:
941
+ raise ValueError("Requested backend is not available.")
942
+ if args.backend == "jax" and callback_class != MaxScoreOverRotations:
943
+ raise ValueError(
944
+ "Jax backend only supports the MaxScoreOverRotations analyzer."
945
+ )
978
946
 
979
- if pref == "pytorch" and args.interpolation_order == 3:
947
+ if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
980
948
  warnings.warn(
981
- "Pytorch does not support --interpolation_order 3, setting it to 1."
949
+ "Jax and pytorch do not support interpolation order 3, setting it to 1."
982
950
  )
983
951
  args.interpolation_order = 1
984
952
 
953
+ if args.backend in ("pytorch", "cupy", "jax"):
954
+ gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
955
+ if gpu_devices is None:
956
+ warnings.warn(
957
+ "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
958
+ "Assuming device 0.",
959
+ )
960
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
961
+ else:
962
+ args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
963
+ args.gpu_indices = [
964
+ int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
965
+ ]
966
+
967
+ # Finally set the desired backend
968
+ device = "cuda"
969
+ be.change_backend(args.backend)
970
+ if args.backend == "pytorch":
971
+ try:
972
+ be.change_backend("pytorch", device=device)
973
+ # Trigger exception if not compiled with device
974
+ be.get_available_memory()
975
+ except Exception as e:
976
+ print(e)
977
+ device = "cpu"
978
+ be.change_backend("pytorch", device=device)
979
+ if args.use_mixed_precision:
980
+ be.change_backend(
981
+ backend_name=args.backend,
982
+ default_dtype=be._array_backend.float16,
983
+ complex_dtype=be._array_backend.complex64,
984
+ default_dtype_int=be._array_backend.int16,
985
+ device=device,
986
+ )
987
+
985
988
  available_memory = be.get_available_memory() * be.device_count()
986
989
  if args.memory is None:
987
990
  args.memory = int(args.memory_scaling * available_memory)
988
991
 
989
- callback_class = MaxScoreOverRotations
990
- if args.peak_calling:
991
- callback_class = PeakCallerMaximumFilter
992
+ if args.orientations_uncertainty is not None:
993
+ args.orientations_uncertainty = tuple(
994
+ int(x) for x in args.orientations_uncertainty.split(",")
995
+ )
992
996
 
993
997
  matching_data = MatchingData(
994
998
  target=target,
@@ -1079,10 +1083,19 @@ def main():
1079
1083
 
1080
1084
  analyzer_args = {
1081
1085
  "score_threshold": args.score_threshold,
1082
- "number_of_peaks": args.number_of_peaks,
1086
+ "num_peaks": args.num_peaks,
1083
1087
  "min_distance": max(template.shape) // 3,
1084
1088
  "use_memmap": args.use_memmap,
1085
1089
  }
1090
+ if args.orientations is not None:
1091
+ analyzer_args["reference"] = (0, 0, 1)
1092
+ analyzer_args["cone_angle"] = args.orientations_cone
1093
+ analyzer_args["acceptance_radius"] = args.orientations_uncertainty
1094
+ analyzer_args["positions"] = args.orientations.translations
1095
+ analyzer_args["rotations"] = euler_to_rotationmatrix(
1096
+ args.orientations.rotations
1097
+ )
1098
+
1086
1099
  print_block(
1087
1100
  name="Analyzer",
1088
1101
  data={"Analyzer": callback_class, **analyzer_args},
@@ -1111,18 +1124,6 @@ def main():
1111
1124
  )
1112
1125
 
1113
1126
  candidates = list(candidates) if candidates is not None else []
1114
- if issubclass(callback_class, MaxScoreOverRotations):
1115
- if target_mask is not None and args.score != "MCC":
1116
- candidates[0] *= target_mask.data
1117
- with warnings.catch_warnings():
1118
- warnings.simplefilter("ignore", category=UserWarning)
1119
- nbytes = be.datatype_bytes(be._float_dtype)
1120
- dtype = np.float32 if nbytes == 4 else np.float16
1121
- rot_dim = matching_data.rotations.shape[1]
1122
- candidates[3] = {
1123
- x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1124
- for i, x in candidates[3].items()
1125
- }
1126
1127
  candidates.append((target.origin, template.origin, template.sampling_rate, args))
1127
1128
  write_pickle(data=candidates, filename=args.output)
1128
1129