pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75):
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -53
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +396 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -201
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +158 -28
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.dist-info/RECORD +0 -119
  70. pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
scripts/match_template.py CHANGED
@@ -1,9 +1,9 @@
1
1
  #!python3
2
- """ CLI for basic pyTME template matching functions.
2
+ """CLI for basic pyTME template matching functions.
3
3
 
4
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
5
 
6
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
  import os
9
9
  import argparse
@@ -12,67 +12,44 @@ from sys import exit
12
12
  from time import time
13
13
  from typing import Tuple
14
14
  from copy import deepcopy
15
- from os.path import exists
16
15
  from tempfile import gettempdir
16
+ from os.path import exists, abspath
17
17
 
18
18
  import numpy as np
19
19
 
20
20
  from tme.backends import backend as be
21
- from tme import Density, __version__
21
+ from tme import Density, __version__, Orientations
22
22
  from tme.matching_utils import scramble_phases, write_pickle
23
23
  from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
24
24
  from tme.rotations import (
25
25
  get_cone_rotations,
26
26
  get_rotation_matrices,
27
+ euler_to_rotationmatrix,
27
28
  )
28
29
  from tme.matching_data import MatchingData
29
30
  from tme.analyzer import (
30
31
  MaxScoreOverRotations,
31
32
  PeakCallerMaximumFilter,
33
+ MaxScoreOverRotationsConstrained,
32
34
  )
33
35
  from tme.filters import (
34
36
  CTF,
35
37
  Wedge,
36
38
  Compose,
37
- BandPassFilter,
39
+ BandPass,
40
+ CTFReconstructed,
38
41
  WedgeReconstructed,
39
42
  ReconstructFromTilt,
40
43
  LinearWhiteningFilter,
44
+ BandPassReconstructed,
45
+ )
46
+ from tme.cli import (
47
+ get_func_fullname,
48
+ print_block,
49
+ print_entry,
50
+ check_positive,
51
+ sanitize_name,
41
52
  )
42
-
43
-
44
- def get_func_fullname(func) -> str:
45
- """Returns the full name of the given function, including its module."""
46
- return f"<function '{func.__module__}.{func.__name__}'>"
47
-
48
-
49
- def print_block(name: str, data: dict, label_width=20) -> None:
50
- """Prints a formatted block of information."""
51
- print(f"\n> {name}")
52
- for key, value in data.items():
53
- if isinstance(value, np.ndarray):
54
- value = value.shape
55
- formatted_value = str(value)
56
- print(f" - {key + ':':<{label_width}} {formatted_value}")
57
-
58
-
59
- def print_entry() -> None:
60
- width = 80
61
- text = f" pytme v{__version__} "
62
- padding_total = width - len(text) - 2
63
- padding_left = padding_total // 2
64
- padding_right = padding_total - padding_left
65
-
66
- print("*" * width)
67
- print(f"*{ ' ' * padding_left }{text}{ ' ' * padding_right }*")
68
- print("*" * width)
69
-
70
-
71
- def check_positive(value):
72
- ivalue = float(value)
73
- if ivalue <= 0:
74
- raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
75
- return ivalue
76
53
 
77
54
 
78
55
  def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
@@ -118,6 +95,14 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
118
95
 
119
96
 
120
97
  def parse_rotation_logic(args, ndim):
98
+ if args.particle_diameter is not None:
99
+ resolution = Density.from_file(args.target, use_memmap=True)
100
+ resolution = 360 * np.maximum(
101
+ np.max(2 * resolution.sampling_rate),
102
+ args.lowpass if args.lowpass is not None else 0,
103
+ )
104
+ args.angular_sampling = resolution / (3.14159265358979 * args.particle_diameter)
105
+
121
106
  if args.angular_sampling is not None:
122
107
  rotations = get_rotation_matrices(
123
108
  angular_sampling=args.angular_sampling,
@@ -138,7 +123,7 @@ def parse_rotation_logic(args, ndim):
138
123
  axis_sampling=args.axis_sampling,
139
124
  n_symmetry=args.axis_symmetry,
140
125
  axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
141
- reference=[0, 0, -1],
126
+ reference=[0, 0, -1 if args.invert_cone else 1],
142
127
  )
143
128
  return rotations
144
129
 
@@ -178,123 +163,82 @@ def compute_schedule(
178
163
 
179
164
 
180
165
  def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
181
- needs_reconstruction = False
182
166
  template_filter, target_filter = [], []
167
+
168
+ if args.tilt_angles is None:
169
+ args.tilt_angles = args.ctf_file
170
+
171
+ wedge = None
183
172
  if args.tilt_angles is not None:
184
- needs_reconstruction = args.tilt_weighting is not None
185
173
  try:
186
174
  wedge = Wedge.from_file(args.tilt_angles)
187
175
  wedge.weight_type = args.tilt_weighting
188
- if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
176
+ if args.tilt_weighting in ("angle", None):
189
177
  wedge = WedgeReconstructed(
190
178
  angles=wedge.angles,
191
179
  weight_wedge=args.tilt_weighting == "angle",
192
- opening_axis=args.wedge_axes[0],
193
- tilt_axis=args.wedge_axes[1],
194
180
  )
195
- except FileNotFoundError:
196
- tilt_step, create_continuous_wedge = None, True
181
+ except (FileNotFoundError, AttributeError):
197
182
  tilt_start, tilt_stop = args.tilt_angles.split(",")
198
- if ":" in tilt_stop:
199
- create_continuous_wedge = False
200
- tilt_stop, tilt_step = tilt_stop.split(":")
201
- tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
202
- tilt_angles = (tilt_start, tilt_stop)
203
- if tilt_step is not None:
204
- tilt_step = float(tilt_step)
205
- tilt_angles = np.arange(
206
- -tilt_start, tilt_stop + tilt_step, tilt_step
207
- ).tolist()
208
-
209
- if args.tilt_weighting is not None and tilt_step is None:
210
- raise ValueError(
211
- "Tilt weighting is not supported for continuous wedges."
212
- )
213
- if args.tilt_weighting not in ("angle", None):
214
- raise ValueError(
215
- "Tilt weighting schemes other than 'angle' or 'None' require "
216
- "a specification of electron doses via --tilt_angles."
217
- )
218
-
219
- wedge = Wedge(
220
- angles=tilt_angles,
221
- opening_axis=args.wedge_axes[0],
222
- tilt_axis=args.wedge_axes[1],
223
- shape=None,
224
- weight_type=None,
225
- weights=np.ones_like(tilt_angles),
183
+ tilt_start, tilt_stop = abs(float(tilt_start)), abs(float(tilt_stop))
184
+ wedge = WedgeReconstructed(
185
+ angles=(tilt_start, tilt_stop),
186
+ create_continuous_wedge=True,
187
+ weight_wedge=False,
188
+ reconstruction_filter=args.reconstruction_filter,
226
189
  )
227
- if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
228
- wedge = WedgeReconstructed(
229
- angles=tilt_angles,
230
- weight_wedge=args.tilt_weighting == "angle",
231
- create_continuous_wedge=create_continuous_wedge,
232
- reconstruction_filter=args.reconstruction_filter,
233
- opening_axis=args.wedge_axes[0],
234
- tilt_axis=args.wedge_axes[1],
235
- )
236
- wedge_target = WedgeReconstructed(
237
- angles=(np.abs(np.min(tilt_angles)), np.abs(np.max(tilt_angles))),
238
- weight_wedge=False,
239
- create_continuous_wedge=True,
240
- opening_axis=args.wedge_axes[0],
241
- tilt_axis=args.wedge_axes[1],
242
- )
243
- target_filter.append(wedge_target)
190
+ wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
191
+
192
+ wedge_target = WedgeReconstructed(
193
+ angles=wedge.angles,
194
+ weight_wedge=False,
195
+ create_continuous_wedge=True,
196
+ opening_axis=wedge.opening_axis,
197
+ tilt_axis=wedge.tilt_axis,
198
+ )
244
199
 
245
200
  wedge.sampling_rate = template.sampling_rate
201
+ wedge_target.sampling_rate = template.sampling_rate
202
+
203
+ target_filter.append(wedge_target)
246
204
  template_filter.append(wedge)
247
- if not isinstance(wedge, WedgeReconstructed):
248
- reconstruction_filter = ReconstructFromTilt(
249
- reconstruction_filter=args.reconstruction_filter,
250
- interpolation_order=args.reconstruction_interpolation_order,
251
- )
252
- template_filter.append(reconstruction_filter)
253
205
 
254
206
  if args.ctf_file is not None or args.defocus is not None:
255
- needs_reconstruction = True
256
- if args.ctf_file is not None:
257
- ctf = CTF.from_file(args.ctf_file)
207
+ try:
208
+ ctf = CTF.from_file(
209
+ args.ctf_file,
210
+ spherical_aberration=args.spherical_aberration,
211
+ amplitude_contrast=args.amplitude_contrast,
212
+ acceleration_voltage=args.acceleration_voltage * 1e3,
213
+ )
214
+ if (len(ctf.angles) == 0) and wedge is None:
215
+ raise ValueError(
216
+ "You requested to specify the CTF per tilt, but did not specify "
217
+ "tilt angles via --tilt-angles or --ctf-file. "
218
+ )
219
+ if len(ctf.angles) == 0:
220
+ ctf.angles = wedge.angles
221
+
258
222
  n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
259
- if n_tilts_ctfs != n_tils_angles:
223
+ if (n_tilts_ctfs != n_tils_angles) and isinstance(wedge, Wedge):
260
224
  raise ValueError(
261
- f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
262
- f"recieved {n_tils_angles} tilt angles. Expected one angle "
263
- "per micrograph."
225
+ f"CTF file contains {n_tilts_ctfs} tilt, but recieved "
226
+ f"{n_tils_angles} tilt angles. Expected one angle per tilt"
264
227
  )
265
- ctf.angles = wedge.angles
266
- ctf.no_reconstruction = False
267
- ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
268
- else:
269
- needs_reconstruction = False
270
- ctf = CTF(
271
- defocus_x=[args.defocus],
272
- phase_shift=[args.phase_shift],
273
- defocus_y=None,
274
- angles=[0],
275
- shape=None,
276
- return_real_fourier=True,
228
+
229
+ except (FileNotFoundError, AttributeError):
230
+ ctf = CTFReconstructed(
231
+ defocus_x=args.defocus,
232
+ phase_shift=args.phase_shift,
233
+ amplitude_contrast=args.amplitude_contrast,
234
+ spherical_aberration=args.spherical_aberration,
235
+ acceleration_voltage=args.acceleration_voltage * 1e3,
277
236
  )
278
- ctf.sampling_rate = template.sampling_rate
279
237
  ctf.flip_phase = args.no_flip_phase
280
- ctf.amplitude_contrast = args.amplitude_contrast
281
- ctf.spherical_aberration = args.spherical_aberration
282
- ctf.acceleration_voltage = args.acceleration_voltage * 1e3
238
+ ctf.sampling_rate = template.sampling_rate
239
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
283
240
  ctf.correct_defocus_gradient = args.correct_defocus_gradient
284
-
285
- if not needs_reconstruction:
286
- template_filter.append(ctf)
287
- elif isinstance(template_filter[-1], ReconstructFromTilt):
288
- template_filter.insert(-1, ctf)
289
- else:
290
- template_filter.insert(0, ctf)
291
- template_filter.insert(
292
- 1,
293
- ReconstructFromTilt(
294
- reconstruction_filter=args.reconstruction_filter,
295
- interpolation_order=args.reconstruction_interpolation_order,
296
- ),
297
- )
241
+ template_filter.append(ctf)
298
242
 
299
243
  if args.lowpass or args.highpass is not None:
300
244
  lowpass, highpass = args.lowpass, args.highpass
@@ -315,7 +259,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
315
259
  except Exception:
316
260
  pass
317
261
 
318
- bandpass = BandPassFilter(
262
+ bandpass = BandPassReconstructed(
319
263
  use_gaussian=args.no_pass_smooth,
320
264
  lowpass=lowpass,
321
265
  highpass=highpass,
@@ -329,10 +273,30 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
329
273
  template_filter.append(whitening_filter)
330
274
  target_filter.append(whitening_filter)
331
275
 
332
- if needs_reconstruction and args.reconstruction_filter is None:
276
+ rec_filt = (Wedge, CTF)
277
+ needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
278
+ if needs_reconstruction > 0 and args.reconstruction_filter is None:
333
279
  warnings.warn(
334
- "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
280
+ "Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
281
+ "to avoid artifacts from reconstruction using weighted backprojection."
282
+ )
283
+
284
+ template_filter = sorted(
285
+ template_filter, key=lambda x: type(x) in rec_filt, reverse=True
286
+ )
287
+ if needs_reconstruction > 0:
288
+ relevant_filters = [x for x in template_filter if type(x) in rec_filt]
289
+ if len(relevant_filters) == 0:
290
+ raise ValueError("Filters require ")
291
+
292
+ reconstruction_filter = ReconstructFromTilt(
293
+ reconstruction_filter=args.reconstruction_filter,
294
+ interpolation_order=args.reconstruction_interpolation_order,
295
+ angles=relevant_filters[0].angles,
296
+ opening_axis=args.wedge_axes[0],
297
+ tilt_axis=args.wedge_axes[1],
335
298
  )
299
+ template_filter.insert(needs_reconstruction, reconstruction_filter)
336
300
 
337
301
  template_filter = Compose(template_filter) if len(template_filter) else None
338
302
  target_filter = Compose(target_filter) if len(target_filter) else None
@@ -342,6 +306,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
342
306
  return template_filter, target_filter
343
307
 
344
308
 
309
+ def _format_sampling(arr, decimals: int = 2):
310
+ return tuple(round(float(x), decimals) for x in arr)
311
+
312
+
345
313
  def parse_args():
346
314
  parser = argparse.ArgumentParser(
347
315
  description="Perform template matching.",
@@ -352,7 +320,6 @@ def parse_args():
352
320
  io_group.add_argument(
353
321
  "-m",
354
322
  "--target",
355
- dest="target",
356
323
  type=str,
357
324
  required=True,
358
325
  help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
@@ -360,8 +327,8 @@ def parse_args():
360
327
  "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
361
328
  )
362
329
  io_group.add_argument(
363
- "--target_mask",
364
- dest="target_mask",
330
+ "-M",
331
+ "--target-mask",
365
332
  type=str,
366
333
  required=False,
367
334
  help="Path to a mask for the target in a supported format (see target).",
@@ -369,14 +336,13 @@ def parse_args():
369
336
  io_group.add_argument(
370
337
  "-i",
371
338
  "--template",
372
- dest="template",
373
339
  type=str,
374
340
  required=True,
375
341
  help="Path to a template in PDB/MMCIF or other supported formats (see target).",
376
342
  )
377
343
  io_group.add_argument(
378
- "--template_mask",
379
- dest="template_mask",
344
+ "-I",
345
+ "--template-mask",
380
346
  type=str,
381
347
  required=False,
382
348
  help="Path to a mask for the template in a supported format (see target).",
@@ -384,32 +350,62 @@ def parse_args():
384
350
  io_group.add_argument(
385
351
  "-o",
386
352
  "--output",
387
- dest="output",
388
353
  type=str,
389
354
  required=False,
390
355
  default="output.pickle",
391
356
  help="Path to the output pickle file.",
392
357
  )
393
358
  io_group.add_argument(
394
- "--invert_target_contrast",
395
- dest="invert_target_contrast",
359
+ "--invert-target-contrast",
396
360
  action="store_true",
397
361
  default=False,
398
362
  help="Invert the target's contrast for cases where templates to-be-matched have "
399
363
  "negative values, e.g. tomograms.",
400
364
  )
401
365
  io_group.add_argument(
402
- "--scramble_phases",
403
- dest="scramble_phases",
366
+ "--scramble-phases",
404
367
  action="store_true",
405
368
  default=False,
406
369
  help="Phase scramble the template to generate a noise score background.",
407
370
  )
408
371
 
372
+ sampling_group = parser.add_argument_group("Sampling")
373
+ sampling_group.add_argument(
374
+ "--orientations",
375
+ default=None,
376
+ required=False,
377
+ help="Path to a file readable by Orientations.from_file containing "
378
+ "translations and rotations of candidate peaks to refine.",
379
+ )
380
+ sampling_group.add_argument(
381
+ "--orientations-scaling",
382
+ required=False,
383
+ type=float,
384
+ default=1.0,
385
+ help="Scaling factor to map candidate translations onto the target. "
386
+ "Assuming coordinates are in Å and target sampling rate are 3Å/voxel, "
387
+ "the corresponding --orientations-scaling would be 3.",
388
+ )
389
+ sampling_group.add_argument(
390
+ "--orientations-cone",
391
+ required=False,
392
+ type=float,
393
+ default=20.0,
394
+ help="Accept orientations within specified cone angle of each orientation.",
395
+ )
396
+ sampling_group.add_argument(
397
+ "--orientations-uncertainty",
398
+ required=False,
399
+ type=str,
400
+ default="10",
401
+ help="Accept translations within the specified radius of each orientation. "
402
+ "Can be a single value or comma-separated string for per-axis uncertainty.",
403
+ )
404
+
409
405
  scoring_group = parser.add_argument_group("Scoring")
410
406
  scoring_group.add_argument(
411
407
  "-s",
412
- dest="score",
408
+ "--score",
413
409
  type=str,
414
410
  default="FLCSphericalMask",
415
411
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
@@ -421,67 +417,64 @@ def parse_args():
421
417
 
422
418
  angular_exclusive.add_argument(
423
419
  "-a",
424
- dest="angular_sampling",
420
+ "--angular-sampling",
425
421
  type=check_positive,
426
422
  default=None,
427
- help="Angular sampling rate using optimized rotational sets."
428
- "A lower number yields more rotations. Values >= 180 sample only the identity.",
423
+ help="Angular sampling rate. Lower values = more rotations, higher precision.",
429
424
  )
430
425
  angular_exclusive.add_argument(
431
- "--cone_angle",
432
- dest="cone_angle",
426
+ "--cone-angle",
433
427
  type=check_positive,
434
428
  default=None,
435
429
  help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
436
430
  "narrow interval around a known orientation, e.g. for surface oversampling.",
437
431
  )
432
+ angular_exclusive.add_argument(
433
+ "--particle-diameter",
434
+ type=check_positive,
435
+ default=None,
436
+ help="Particle diameter in units of sampling rate.",
437
+ )
438
438
  angular_group.add_argument(
439
- "--cone_axis",
440
- dest="cone_axis",
439
+ "--cone-axis",
441
440
  type=check_positive,
442
441
  default=2,
443
442
  help="Principal axis to build cone around.",
444
443
  )
445
444
  angular_group.add_argument(
446
- "--invert_cone",
447
- dest="invert_cone",
445
+ "--invert-cone",
448
446
  action="store_true",
449
447
  help="Invert cone handedness.",
450
448
  )
451
449
  angular_group.add_argument(
452
- "--cone_sampling",
453
- dest="cone_sampling",
450
+ "--cone-sampling",
454
451
  type=check_positive,
455
452
  default=None,
456
453
  help="Sampling rate of the cone in degrees.",
457
454
  )
458
455
  angular_group.add_argument(
459
- "--axis_angle",
460
- dest="axis_angle",
456
+ "--axis-angle",
461
457
  type=check_positive,
462
458
  default=360.0,
463
459
  required=False,
464
460
  help="Sampling angle along the z-axis of the cone.",
465
461
  )
466
462
  angular_group.add_argument(
467
- "--axis_sampling",
468
- dest="axis_sampling",
463
+ "--axis-sampling",
469
464
  type=check_positive,
470
465
  default=None,
471
466
  required=False,
472
- help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
467
+ help="Sampling rate along the z-axis of the cone. Defaults to --cone-sampling.",
473
468
  )
474
469
  angular_group.add_argument(
475
- "--axis_symmetry",
476
- dest="axis_symmetry",
470
+ "--axis-symmetry",
477
471
  type=check_positive,
478
472
  default=1,
479
473
  required=False,
480
474
  help="N-fold symmetry around z-axis of the cone.",
481
475
  )
482
476
  angular_group.add_argument(
483
- "--no_use_optimized_set",
484
- dest="no_use_optimized_set",
477
+ "--no-use-optimized-set",
485
478
  action="store_true",
486
479
  default=False,
487
480
  required=False,
@@ -498,56 +491,40 @@ def parse_args():
498
491
  help="Number of cores used for template matching.",
499
492
  )
500
493
  computation_group.add_argument(
501
- "--use_gpu",
502
- dest="use_gpu",
503
- action="store_true",
504
- default=False,
505
- help="Whether to perform computations on the GPU.",
506
- )
507
- computation_group.add_argument(
508
- "--gpu_indices",
509
- dest="gpu_indices",
494
+ "--gpu-indices",
510
495
  type=str,
511
496
  default=None,
512
- help="Comma-separated list of GPU indices to use. For example,"
513
- " 0,1 for the first and second GPU. Only used if --use_gpu is set."
514
- " If not provided but --use_gpu is set, CUDA_VISIBLE_DEVICES will"
515
- " be respected.",
497
+ help="Comma-separated GPU indices (e.g., '0,1,2' for first 3 GPUs). Otherwise "
498
+ "CUDA_VISIBLE_DEVICES will be used.",
516
499
  )
517
500
  computation_group.add_argument(
518
- "-r",
519
- "--ram",
520
- dest="memory",
501
+ "--memory",
521
502
  required=False,
522
503
  type=int,
523
504
  default=None,
524
505
  help="Amount of memory that can be used in bytes.",
525
506
  )
526
507
  computation_group.add_argument(
527
- "--memory_scaling",
528
- dest="memory_scaling",
508
+ "--memory-scaling",
529
509
  required=False,
530
510
  type=float,
531
511
  default=0.85,
532
- help="Fraction of available memory to be used. Ignored if --ram is set.",
512
+ help="Fraction of available memory to be used. Ignored if --memory is set.",
533
513
  )
534
514
  computation_group.add_argument(
535
- "--temp_directory",
536
- dest="temp_directory",
537
- default=None,
538
- help="Directory for temporary objects. Faster I/O improves runtime.",
515
+ "--temp-directory",
516
+ default=gettempdir(),
517
+ help="Temporary directory for memmaps. Better I/O improves runtime.",
539
518
  )
540
519
  computation_group.add_argument(
541
520
  "--backend",
542
- dest="backend",
543
- default=None,
521
+ default=be._backend_name,
544
522
  choices=be.available_backends(),
545
- help="[Expert] Overwrite default computation backend.",
523
+ help="Set computation backend.",
546
524
  )
547
525
  filter_group = parser.add_argument_group("Filters")
548
526
  filter_group.add_argument(
549
527
  "--lowpass",
550
- dest="lowpass",
551
528
  type=float,
552
529
  required=False,
553
530
  help="Resolution to lowpass filter template and target to in the same unit "
@@ -555,58 +532,54 @@ def parse_args():
555
532
  )
556
533
  filter_group.add_argument(
557
534
  "--highpass",
558
- dest="highpass",
559
535
  type=float,
560
536
  required=False,
561
537
  help="Resolution to highpass filter template and target to in the same unit "
562
538
  "as the sampling rate of template and target (typically Ångstrom).",
563
539
  )
564
540
  filter_group.add_argument(
565
- "--no_pass_smooth",
566
- dest="no_pass_smooth",
541
+ "--no-pass-smooth",
567
542
  action="store_false",
568
543
  default=True,
569
544
  help="Whether a hard edge filter should be used for --lowpass and --highpass.",
570
545
  )
571
546
  filter_group.add_argument(
572
- "--pass_format",
573
- dest="pass_format",
547
+ "--pass-format",
574
548
  type=str,
575
549
  required=False,
576
550
  default="sampling_rate",
577
551
  choices=["sampling_rate", "voxel", "frequency"],
578
- help="How values passed to --lowpass and --highpass should be interpreted. ",
552
+ help="How values passed to --lowpass and --highpass should be interpreted. "
553
+ "Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
579
554
  )
580
555
  filter_group.add_argument(
581
- "--whiten_spectrum",
582
- dest="whiten_spectrum",
556
+ "--whiten-spectrum",
583
557
  action="store_true",
584
558
  default=None,
585
559
  help="Apply spectral whitening to template and target based on target spectrum.",
586
560
  )
587
561
  filter_group.add_argument(
588
- "--wedge_axes",
589
- dest="wedge_axes",
562
+ "--wedge-axes",
590
563
  type=str,
591
564
  required=False,
592
- default=None,
593
- help="Indices of wedge opening and tilt axis, e.g. '2,0' for a wedge open "
594
- "in z and tilted over the x-axis.",
565
+ default="2,0",
566
+ help="Indices of projection (wedge opening) and tilt axis, e.g., '2,0' "
567
+ "for the typical projection over z and tilting over the x-axis.",
595
568
  )
596
569
  filter_group.add_argument(
597
- "--tilt_angles",
598
- dest="tilt_angles",
570
+ "--tilt-angles",
599
571
  type=str,
600
572
  required=False,
601
573
  default=None,
602
- help="Path to a tab-separated file containing the column angles and optionally "
603
- " weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
604
- " yields a continuous wedge mask. Alternatively, a tilt step size can be "
605
- "specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
574
+ help="Path to a file specifying tilt angles. This can be a Warp/M XML file, "
575
+ "a tomostar STAR file, an MMOD file, a tab-separated file with column name "
576
+ "'angles', or a single column file without header. Exposure will be taken from "
577
+ "the input file , if you are using a tab-separated file, the column names "
578
+ "'angles' and 'weights' need to be present. It is also possible to specify a "
579
+ "continuous wedge mask using e.g., -50,45.",
606
580
  )
607
581
  filter_group.add_argument(
608
- "--tilt_weighting",
609
- dest="tilt_weighting",
582
+ "--tilt-weighting",
610
583
  type=str,
611
584
  required=False,
612
585
  choices=["angle", "relion", "grigorieff"],
@@ -615,28 +588,25 @@ def parse_args():
615
588
  "angle (cosine based weighting), "
616
589
  "relion (relion formalism for wedge weighting) requires,"
617
590
  "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
618
- "relion and grigorieff require electron doses in --tilt_angles weights column.",
591
+ "relion and grigorieff require electron doses in --tilt-angles weights column.",
619
592
  )
620
593
  filter_group.add_argument(
621
- "--reconstruction_filter",
622
- dest="reconstruction_filter",
594
+ "--reconstruction-filter",
623
595
  type=str,
624
596
  required=False,
625
597
  choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
626
- default=None,
598
+ default="ramp",
627
599
  help="Filter applied when reconstructing (N+1)-D from N-D filters.",
628
600
  )
629
601
  filter_group.add_argument(
630
- "--reconstruction_interpolation_order",
631
- dest="reconstruction_interpolation_order",
602
+ "--reconstruction-interpolation-order",
632
603
  type=int,
633
604
  default=1,
634
605
  required=False,
635
- help="Analogous to --interpolation_order but for reconstruction.",
606
+ help="Analogous to --interpolation-order but for reconstruction.",
636
607
  )
637
608
  filter_group.add_argument(
638
- "--no_filter_target",
639
- dest="no_filter_target",
609
+ "--no-filter-target",
640
610
  action="store_true",
641
611
  default=False,
642
612
  help="Whether to not apply potential filters to the target.",
@@ -644,65 +614,58 @@ def parse_args():
644
614
 
645
615
  ctf_group = parser.add_argument_group("Contrast Transfer Function")
646
616
  ctf_group.add_argument(
647
- "--ctf_file",
648
- dest="ctf_file",
617
+ "--ctf-file",
649
618
  type=str,
650
619
  required=False,
651
620
  default=None,
652
- help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
653
- "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
621
+ help="Path to a file with CTF parameters. This can be a Warp/M XML file "
622
+ "a GCTF/Relion STAR file, an MDOC file, or the output of CTFFIND4. If the file "
623
+ " does not specify tilt angles, --tilt-angles are used.",
654
624
  )
655
625
  ctf_group.add_argument(
656
626
  "--defocus",
657
- dest="defocus",
658
627
  type=float,
659
628
  required=False,
660
629
  default=None,
661
- help="Defocus in units of sampling rate (typically Ångstrom). "
662
- "Superseded by --ctf_file.",
630
+ help="Defocus in units of sampling rate (typically Ångstrom), e.g., 30000 "
631
+ "for a defocus of 3 micrometer. Superseded by --ctf-file.",
663
632
  )
664
633
  ctf_group.add_argument(
665
- "--phase_shift",
666
- dest="phase_shift",
634
+ "--phase-shift",
667
635
  type=float,
668
636
  required=False,
669
637
  default=0,
670
- help="Phase shift in degrees. Superseded by --ctf_file.",
638
+ help="Phase shift in degrees. Superseded by --ctf-file.",
671
639
  )
672
640
  ctf_group.add_argument(
673
- "--acceleration_voltage",
674
- dest="acceleration_voltage",
641
+ "--acceleration-voltage",
675
642
  type=float,
676
643
  required=False,
677
644
  default=300,
678
645
  help="Acceleration voltage in kV.",
679
646
  )
680
647
  ctf_group.add_argument(
681
- "--spherical_aberration",
682
- dest="spherical_aberration",
648
+ "--spherical-aberration",
683
649
  type=float,
684
650
  required=False,
685
651
  default=2.7e7,
686
652
  help="Spherical aberration in units of sampling rate (typically Ångstrom).",
687
653
  )
688
654
  ctf_group.add_argument(
689
- "--amplitude_contrast",
690
- dest="amplitude_contrast",
655
+ "--amplitude-contrast",
691
656
  type=float,
692
657
  required=False,
693
658
  default=0.07,
694
659
  help="Amplitude contrast.",
695
660
  )
696
661
  ctf_group.add_argument(
697
- "--no_flip_phase",
698
- dest="no_flip_phase",
662
+ "--no-flip-phase",
699
663
  action="store_false",
700
664
  required=False,
701
665
  help="Do not perform phase-flipping CTF correction.",
702
666
  )
703
667
  ctf_group.add_argument(
704
- "--correct_defocus_gradient",
705
- dest="correct_defocus_gradient",
668
+ "--correct-defocus-gradient",
706
669
  action="store_true",
707
670
  required=False,
708
671
  help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
@@ -711,55 +674,49 @@ def parse_args():
711
674
 
712
675
  performance_group = parser.add_argument_group("Performance")
713
676
  performance_group.add_argument(
714
- "--no_centering",
715
- dest="no_centering",
677
+ "--centering",
716
678
  action="store_true",
717
- help="Assumes the template is already centered and omits centering.",
679
+ help="Center the template in the box if it has not been done already.",
718
680
  )
719
681
  performance_group.add_argument(
720
- "--pad_edges",
721
- dest="pad_edges",
682
+ "--pad-edges",
722
683
  action="store_true",
723
684
  default=False,
724
- help="Whether to pad the edges of the target. Useful if the target does not "
725
- "a well-defined bounding box. Defaults to True if splitting is required.",
685
+ help="Useful if the target does not have a well-defined bounding box. Will be "
686
+ "activated automatically if splitting is required to avoid boundary artifacts.",
726
687
  )
727
688
  performance_group.add_argument(
728
- "--pad_filter",
729
- dest="pad_filter",
689
+ "--pad-filter",
730
690
  action="store_true",
731
691
  default=False,
732
- help="Pads the filter to the shape of the target. Particularly useful for fast "
692
+ help="Pad the template filter to the shape of the target. Useful for fast "
733
693
  "oscilating filters to avoid aliasing effects.",
734
694
  )
735
695
  performance_group.add_argument(
736
- "--interpolation_order",
737
- dest="interpolation_order",
696
+ "--interpolation-order",
738
697
  required=False,
739
698
  type=int,
740
- default=3,
741
- help="Spline interpolation used for rotations.",
699
+ default=None,
700
+ help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
701
+ "and pytorch backends.",
742
702
  )
743
703
  performance_group.add_argument(
744
- "--use_mixed_precision",
745
- dest="use_mixed_precision",
704
+ "--use-mixed-precision",
746
705
  action="store_true",
747
706
  default=False,
748
- help="Use float16 for real values operations where possible.",
707
+ help="Use float16 for real values operations where possible. Not supported "
708
+ "for jax backend.",
749
709
  )
750
710
  performance_group.add_argument(
751
- "--use_memmap",
752
- dest="use_memmap",
711
+ "--use-memmap",
753
712
  action="store_true",
754
713
  default=False,
755
- help="Use memmaps to offload large data objects to disk. "
756
- "Particularly useful for large inputs in combination with --use_gpu.",
714
+ help="Memmap large data to disk, e.g., matching on unbinned tomograms.",
757
715
  )
758
716
 
759
- analyzer_group = parser.add_argument_group("Analyzer")
717
+ analyzer_group = parser.add_argument_group("Output / Analysis")
760
718
  analyzer_group.add_argument(
761
- "--score_threshold",
762
- dest="score_threshold",
719
+ "--score-threshold",
763
720
  required=False,
764
721
  type=float,
765
722
  default=0,
@@ -767,21 +724,25 @@ def parse_args():
767
724
  )
768
725
  analyzer_group.add_argument(
769
726
  "-p",
770
- dest="peak_calling",
727
+ "--peak-calling",
771
728
  action="store_true",
772
729
  default=False,
773
730
  help="Perform peak calling instead of score aggregation.",
774
731
  )
775
732
  analyzer_group.add_argument(
776
- "--number_of_peaks",
777
- dest="number_of_peaks",
778
- action="store_true",
733
+ "--num-peaks",
734
+ type=int,
779
735
  default=1000,
780
736
  help="Number of peaks to call, 1000 by default.",
781
737
  )
782
738
  args = parser.parse_args()
783
739
  args.version = __version__
784
740
 
741
+ if args.interpolation_order is None:
742
+ args.interpolation_order = 3
743
+ if args.backend in ("jax", "pytorch"):
744
+ args.interpolation_order = 1
745
+
785
746
  if args.interpolation_order < 0:
786
747
  args.interpolation_order = None
787
748
 
@@ -789,41 +750,42 @@ def parse_args():
789
750
  args.temp_directory = gettempdir()
790
751
 
791
752
  os.environ["TMPDIR"] = args.temp_directory
792
- if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
793
- raise ValueError(
794
- f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
795
- )
796
-
797
- gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
798
753
  if args.gpu_indices is not None:
799
754
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
800
755
 
801
- if args.use_gpu:
802
- gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
803
- if gpu_devices is None:
804
- print(
805
- "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set.",
806
- "Assuming device 0.",
756
+ if args.tilt_angles is not None and not exists(args.tilt_angles):
757
+ try:
758
+ float(args.tilt_angles.split(",")[0])
759
+ except Exception:
760
+ raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
761
+
762
+ if args.ctf_file is not None and args.tilt_angles is None:
763
+ # Check if tilt angles can be extracted from CTF specification
764
+ try:
765
+ ctf = CTF.from_file(args.ctf_file)
766
+ if ctf.angles is None:
767
+ raise ValueError
768
+ args.tilt_angles = args.ctf_file
769
+ except Exception:
770
+ raise ValueError(
771
+ "Need to specify --tilt-angles when not provided in --ctf-file."
807
772
  )
808
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
809
- args.gpu_indices = [
810
- int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
811
- ]
812
773
 
813
- if args.tilt_angles is not None:
814
- if args.wedge_axes is None:
815
- raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
816
- if not exists(args.tilt_angles):
817
- try:
818
- float(args.tilt_angles.split(",")[0])
819
- except ValueError:
820
- raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
774
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
775
+ if args.orientations is not None:
776
+ orientations = Orientations.from_file(args.orientations)
777
+ orientations.translations = np.divide(
778
+ orientations.translations, args.orientations_scaling
779
+ )
780
+ args.orientations = orientations
821
781
 
822
- if args.ctf_file is not None and args.tilt_angles is None:
823
- raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
782
+ args.target = abspath(args.target)
783
+ if args.target_mask is not None:
784
+ args.target_mask = abspath(args.target_mask)
824
785
 
825
- if args.wedge_axes is not None:
826
- args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
786
+ args.template = abspath(args.template)
787
+ if args.template_mask is not None:
788
+ args.template_mask = abspath(args.template_mask)
827
789
 
828
790
  return args
829
791
 
@@ -864,7 +826,7 @@ def main():
864
826
  name="Target",
865
827
  data={
866
828
  "Initial Shape": initial_shape,
867
- "Sampling Rate": tuple(np.round(target.sampling_rate, 2)),
829
+ "Sampling Rate": _format_sampling(target.sampling_rate),
868
830
  "Final Shape": target.shape,
869
831
  },
870
832
  )
@@ -874,27 +836,28 @@ def main():
874
836
  name="Target Mask",
875
837
  data={
876
838
  "Initial Shape": initial_shape,
877
- "Sampling Rate": tuple(np.round(target_mask.sampling_rate, 2)),
839
+ "Sampling Rate": _format_sampling(target_mask.sampling_rate),
878
840
  "Final Shape": target_mask.shape,
879
841
  },
880
842
  )
881
843
 
882
844
  initial_shape = template.shape
883
845
  translation = np.zeros(len(template.shape), dtype=np.float32)
884
- if not args.no_centering:
846
+ if args.centering:
885
847
  template, translation = template.centered(0)
848
+
886
849
  print_block(
887
850
  name="Template",
888
851
  data={
889
852
  "Initial Shape": initial_shape,
890
- "Sampling Rate": tuple(np.round(template.sampling_rate, 2)),
853
+ "Sampling Rate": _format_sampling(template.sampling_rate),
891
854
  "Final Shape": template.shape,
892
855
  },
893
856
  )
894
857
 
895
858
  if template_mask is None:
896
859
  template_mask = template.empty
897
- if not args.no_centering:
860
+ if not args.centering:
898
861
  enclosing_box = template.minimum_enclosing_box(
899
862
  0, use_geometric_center=False
900
863
  )
@@ -919,7 +882,7 @@ def main():
919
882
  name="Template Mask",
920
883
  data={
921
884
  "Inital Shape": initial_shape,
922
- "Sampling Rate": tuple(np.round(template_mask.sampling_rate, 2)),
885
+ "Sampling Rate": _format_sampling(template_mask.sampling_rate),
923
886
  "Final Shape": template_mask.shape,
924
887
  },
925
888
  )
@@ -930,65 +893,78 @@ def main():
930
893
  template.data, noise_proportion=1.0, normalize_power=False
931
894
  )
932
895
 
896
+ callback_class = MaxScoreOverRotations
897
+ if args.peak_calling:
898
+ callback_class = PeakCallerMaximumFilter
899
+
900
+ if args.orientations is not None:
901
+ callback_class = MaxScoreOverRotationsConstrained
902
+
933
903
  # Determine suitable backend for the selected operation
934
904
  available_backends = be.available_backends()
935
- if args.backend is not None:
936
- req_backend = args.backend
937
- if req_backend not in available_backends:
938
- raise ValueError("Requested backend is not available.")
939
- available_backends = [req_backend]
940
-
941
- be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
942
- if args.use_gpu:
943
- args.cores = len(args.gpu_indices)
944
- be_selection = ("pytorch", "cupy", "jax")
945
- if args.use_mixed_precision:
946
- be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
947
-
948
- available_backends = [x for x in available_backends if x in be_selection]
949
- if args.peak_calling:
950
- if "jax" in available_backends:
951
- available_backends.remove("jax")
952
- if args.use_gpu and "pytorch" in available_backends:
953
- available_backends = ("pytorch",)
954
-
955
- # dim_match = len(template.shape) == len(target.shape) <= 3
956
- # if dim_match and args.use_gpu and "jax" in available_backends:
957
- # args.interpolation_order = 1
958
- # available_backends = ["jax"]
959
-
960
- backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
961
- if args.use_gpu:
962
- backend_preference = ("cupy", "pytorch", "jax")
963
- for pref in backend_preference:
964
- if pref not in available_backends:
965
- continue
966
- be.change_backend(pref)
967
- if pref == "pytorch":
968
- be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
969
-
970
- if args.use_mixed_precision:
971
- be.change_backend(
972
- backend_name=pref,
973
- default_dtype=be._array_backend.float16,
974
- complex_dtype=be._array_backend.complex64,
975
- default_dtype_int=be._array_backend.int16,
976
- )
977
- break
905
+ if args.backend not in available_backends:
906
+ raise ValueError("Requested backend is not available.")
907
+ if args.backend == "jax" and callback_class != MaxScoreOverRotations:
908
+ raise ValueError(
909
+ "Jax backend only supports the MaxScoreOverRotations analyzer."
910
+ )
978
911
 
979
- if pref == "pytorch" and args.interpolation_order == 3:
912
+ if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
980
913
  warnings.warn(
981
- "Pytorch does not support --interpolation_order 3, setting it to 1."
914
+ "Jax and pytorch do not support interpolation order 3, setting it to 1."
982
915
  )
983
916
  args.interpolation_order = 1
984
917
 
918
+ if args.backend in ("pytorch", "cupy", "jax"):
919
+ gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
920
+ if gpu_devices is None:
921
+ warnings.warn(
922
+ "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
923
+ "Assuming device 0.",
924
+ )
925
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
926
+
927
+ args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
928
+ args.gpu_indices = [
929
+ int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
930
+ ]
931
+
932
+ # Finally set the desired backend
933
+ device = "cuda"
934
+ args.use_gpu = False
935
+ be.change_backend(args.backend)
936
+ if args.backend in ("jax", "pytorch", "cupy"):
937
+ args.use_gpu = True
938
+
939
+ if args.backend == "pytorch":
940
+ try:
941
+ be.change_backend("pytorch", device=device)
942
+ # Trigger exception if not compiled with device
943
+ be.get_available_memory()
944
+ except Exception as e:
945
+ print(e)
946
+ device = "cpu"
947
+ args.use_gpu = False
948
+ be.change_backend("pytorch", device=device)
949
+
950
+ # TODO: Make the inverse casting from complex64 -> float 16 stable
951
+ # if args.use_mixed_precision:
952
+ # be.change_backend(
953
+ # backend_name=args.backend,
954
+ # float_dtype=be._array_backend.float16,
955
+ # complex_dtype=be._array_backend.complex64,
956
+ # int_dtype=be._array_backend.int16,
957
+ # device=device,
958
+ # )
959
+
985
960
  available_memory = be.get_available_memory() * be.device_count()
986
961
  if args.memory is None:
987
962
  args.memory = int(args.memory_scaling * available_memory)
988
963
 
989
- callback_class = MaxScoreOverRotations
990
- if args.peak_calling:
991
- callback_class = PeakCallerMaximumFilter
964
+ if args.orientations_uncertainty is not None:
965
+ args.orientations_uncertainty = tuple(
966
+ int(x) for x in args.orientations_uncertainty.split(",")
967
+ )
992
968
 
993
969
  matching_data = MatchingData(
994
970
  target=target,
@@ -1018,7 +994,7 @@ def main():
1018
994
  options = {
1019
995
  "Angular Sampling": f"{args.angular_sampling}"
1020
996
  f" [{matching_data.rotations.shape[0]} rotations]",
1021
- "Center Template": not args.no_centering,
997
+ "Center Template": args.centering,
1022
998
  "Scramble Template": args.scramble_phases,
1023
999
  "Invert Contrast": args.invert_target_contrast,
1024
1000
  "Extend Target Edges": args.pad_edges,
@@ -1061,12 +1037,7 @@ def main():
1061
1037
  }
1062
1038
  if args.ctf_file is not None or args.defocus is not None:
1063
1039
  filter_args["CTF File"] = args.ctf_file
1064
- filter_args["Defocus"] = args.defocus
1065
- filter_args["Phase Shift"] = args.phase_shift
1066
1040
  filter_args["Flip Phase"] = args.no_flip_phase
1067
- filter_args["Acceleration Voltage"] = args.acceleration_voltage
1068
- filter_args["Spherical Aberration"] = args.spherical_aberration
1069
- filter_args["Amplitude Contrast"] = args.amplitude_contrast
1070
1041
  filter_args["Correct Defocus"] = args.correct_defocus_gradient
1071
1042
 
1072
1043
  filter_args = {k: v for k, v in filter_args.items() if v is not None}
@@ -1079,13 +1050,25 @@ def main():
1079
1050
 
1080
1051
  analyzer_args = {
1081
1052
  "score_threshold": args.score_threshold,
1082
- "number_of_peaks": args.number_of_peaks,
1053
+ "num_peaks": args.num_peaks,
1083
1054
  "min_distance": max(template.shape) // 3,
1084
1055
  "use_memmap": args.use_memmap,
1085
1056
  }
1057
+ if args.orientations is not None:
1058
+ analyzer_args["reference"] = (0, 0, 1)
1059
+ analyzer_args["cone_angle"] = args.orientations_cone
1060
+ analyzer_args["acceptance_radius"] = args.orientations_uncertainty
1061
+ analyzer_args["positions"] = args.orientations.translations
1062
+ analyzer_args["rotations"] = euler_to_rotationmatrix(
1063
+ args.orientations.rotations
1064
+ )
1065
+
1086
1066
  print_block(
1087
1067
  name="Analyzer",
1088
- data={"Analyzer": callback_class, **analyzer_args},
1068
+ data={
1069
+ "Analyzer": callback_class,
1070
+ **{sanitize_name(k): v for k, v in analyzer_args.items()},
1071
+ },
1089
1072
  label_width=max(len(key) for key in options.keys()) + 3,
1090
1073
  )
1091
1074
  print("\n" + "-" * 80)
@@ -1111,18 +1094,6 @@ def main():
1111
1094
  )
1112
1095
 
1113
1096
  candidates = list(candidates) if candidates is not None else []
1114
- if issubclass(callback_class, MaxScoreOverRotations):
1115
- if target_mask is not None and args.score != "MCC":
1116
- candidates[0] *= target_mask.data
1117
- with warnings.catch_warnings():
1118
- warnings.simplefilter("ignore", category=UserWarning)
1119
- nbytes = be.datatype_bytes(be._float_dtype)
1120
- dtype = np.float32 if nbytes == 4 else np.float16
1121
- rot_dim = matching_data.rotations.shape[1]
1122
- candidates[3] = {
1123
- x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1124
- for i, x in candidates[3].items()
1125
- }
1126
1097
  candidates.append((target.origin, template.origin, template.sampling_rate, args))
1127
1098
  write_pickle(data=candidates, filename=args.output)
1128
1099