pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
scripts/match_template.py
CHANGED
@@ -19,8 +19,9 @@ import numpy as np
|
|
19
19
|
|
20
20
|
from tme.backends import backend as be
|
21
21
|
from tme import Density, __version__, Orientations
|
22
|
-
from tme.matching_utils import
|
23
|
-
from tme.matching_exhaustive import
|
22
|
+
from tme.matching_utils import write_pickle
|
23
|
+
from tme.matching_exhaustive import match_exhaustive
|
24
|
+
from tme.matching_scores import MATCHING_EXHAUSTIVE_REGISTER
|
24
25
|
from tme.rotations import (
|
25
26
|
get_cone_rotations,
|
26
27
|
get_rotation_matrices,
|
@@ -37,6 +38,7 @@ from tme.filters import (
|
|
37
38
|
Wedge,
|
38
39
|
Compose,
|
39
40
|
BandPass,
|
41
|
+
ShiftFourier,
|
40
42
|
CTFReconstructed,
|
41
43
|
WedgeReconstructed,
|
42
44
|
ReconstructFromTilt,
|
@@ -129,21 +131,14 @@ def parse_rotation_logic(args, ndim):
|
|
129
131
|
|
130
132
|
|
131
133
|
def compute_schedule(
|
132
|
-
args,
|
133
|
-
matching_data: MatchingData,
|
134
|
-
callback_class,
|
135
|
-
pad_edges: bool = False,
|
134
|
+
args, matching_data: MatchingData, callback_class, use_gpu: bool = False
|
136
135
|
):
|
137
|
-
# User requested target padding
|
138
|
-
if args.pad_edges is True:
|
139
|
-
pad_edges = True
|
140
|
-
|
141
136
|
splits, schedule = matching_data.computation_schedule(
|
142
137
|
matching_method=args.score,
|
143
138
|
analyzer_method=callback_class.__name__,
|
144
|
-
use_gpu=
|
139
|
+
use_gpu=use_gpu,
|
145
140
|
pad_fourier=False,
|
146
|
-
pad_target_edges=pad_edges,
|
141
|
+
pad_target_edges=args.pad_edges,
|
147
142
|
available_memory=args.memory,
|
148
143
|
max_cores=args.cores,
|
149
144
|
)
|
@@ -155,53 +150,63 @@ def compute_schedule(
|
|
155
150
|
)
|
156
151
|
exit(-1)
|
157
152
|
|
153
|
+
# Padding is required to avoid artifacts so setting it
|
158
154
|
n_splits = np.prod(list(splits.values()))
|
159
|
-
if pad_edges is False and len(matching_data._target_dim) == 0 and n_splits > 1:
|
155
|
+
if args.pad_edges is False and len(matching_data._target_dim) == 0 and n_splits > 1:
|
156
|
+
warnings.warn("Setting --pad-edges to avoid artifacts from splitting.")
|
160
157
|
args.pad_edges = True
|
161
|
-
return compute_schedule(args, matching_data, callback_class,
|
158
|
+
return compute_schedule(args, matching_data, callback_class, use_gpu)
|
162
159
|
return splits, schedule
|
163
160
|
|
164
161
|
|
165
162
|
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
166
163
|
template_filter, target_filter = [], []
|
167
164
|
|
168
|
-
if args.tilt_angles is None:
|
169
|
-
args.tilt_angles = args.ctf_file
|
170
|
-
|
171
165
|
wedge = None
|
172
166
|
if args.tilt_angles is not None:
|
173
167
|
try:
|
174
168
|
wedge = Wedge.from_file(args.tilt_angles)
|
175
169
|
wedge.weight_type = args.tilt_weighting
|
176
|
-
|
170
|
+
|
171
|
+
# Avoid reconstructing the 3D wedge from individual tilts
|
172
|
+
if args.tilt_weighting in ("angle", None) and not args.match_projection:
|
177
173
|
wedge = WedgeReconstructed(
|
178
174
|
angles=wedge.angles,
|
179
175
|
weight_wedge=args.tilt_weighting == "angle",
|
180
176
|
)
|
177
|
+
|
181
178
|
except (FileNotFoundError, AttributeError):
|
182
|
-
tilt_start, tilt_stop = args.tilt_angles.split(",")
|
183
|
-
tilt_start, tilt_stop = abs(float(tilt_start)), abs(float(tilt_stop))
|
184
179
|
wedge = WedgeReconstructed(
|
185
|
-
angles=
|
180
|
+
angles=args.tilt_angles,
|
186
181
|
create_continuous_wedge=True,
|
187
182
|
weight_wedge=False,
|
188
183
|
reconstruction_filter=args.reconstruction_filter,
|
189
184
|
)
|
190
|
-
wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
|
191
|
-
|
192
|
-
wedge_target = WedgeReconstructed(
|
193
|
-
angles=wedge.angles,
|
194
|
-
weight_wedge=False,
|
195
|
-
create_continuous_wedge=True,
|
196
|
-
opening_axis=wedge.opening_axis,
|
197
|
-
tilt_axis=wedge.tilt_axis,
|
198
|
-
)
|
199
185
|
|
200
186
|
wedge.sampling_rate = template.sampling_rate
|
201
|
-
|
187
|
+
wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
|
188
|
+
template_filter.append(wedge)
|
189
|
+
|
190
|
+
# When projection matching we can reuse the template wedge mask
|
191
|
+
wedge_target = wedge
|
192
|
+
if not args.match_projection:
|
193
|
+
wedge_target = WedgeReconstructed(
|
194
|
+
angles=wedge.angles,
|
195
|
+
weight_wedge=False,
|
196
|
+
create_continuous_wedge=True,
|
197
|
+
opening_axis=wedge.opening_axis,
|
198
|
+
tilt_axis=wedge.tilt_axis,
|
199
|
+
)
|
202
200
|
|
201
|
+
wedge_target.sampling_rate = template.sampling_rate
|
202
|
+
else:
|
203
|
+
n_angles, n_tilts = len(wedge_target.angles), target.shape[0]
|
204
|
+
if n_angles != n_tilts:
|
205
|
+
raise ValueError(
|
206
|
+
f"Target contains {n_tilts} tilts, but the input specified "
|
207
|
+
f"{n_angles} tilt angles."
|
208
|
+
)
|
203
209
|
target_filter.append(wedge_target)
|
204
|
-
template_filter.append(wedge)
|
205
210
|
|
206
211
|
if args.ctf_file is not None or args.defocus is not None:
|
207
212
|
try:
|
@@ -219,15 +224,19 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
219
224
|
if len(ctf.angles) == 0:
|
220
225
|
ctf.angles = wedge.angles
|
221
226
|
|
227
|
+
# There are several ways we can end up here. Bottom line, we are using
|
228
|
+
# a non-reconstructed wedge, which contains a different number of tilts
|
229
|
+
# than the ctf. We use defocus_x, as not all ctf_files specify angles.
|
222
230
|
n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
|
223
|
-
if (n_tilts_ctfs != n_tils_angles) and
|
231
|
+
if (n_tilts_ctfs != n_tils_angles) and type(wedge) is Wedge:
|
224
232
|
raise ValueError(
|
225
233
|
f"CTF file contains {n_tilts_ctfs} tilt, but recieved "
|
226
234
|
f"{n_tils_angles} tilt angles. Expected one angle per tilt"
|
227
235
|
)
|
228
236
|
|
229
237
|
except (FileNotFoundError, AttributeError):
|
230
|
-
|
238
|
+
ctf_cl = CTFReconstructed if not args.match_projection else CTF
|
239
|
+
ctf = ctf_cl(
|
231
240
|
defocus_x=args.defocus,
|
232
241
|
phase_shift=args.phase_shift,
|
233
242
|
amplitude_contrast=args.amplitude_contrast,
|
@@ -237,10 +246,9 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
237
246
|
ctf.flip_phase = args.no_flip_phase
|
238
247
|
ctf.sampling_rate = template.sampling_rate
|
239
248
|
ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
|
240
|
-
ctf.correct_defocus_gradient = args.correct_defocus_gradient
|
241
249
|
template_filter.append(ctf)
|
242
250
|
|
243
|
-
if args.lowpass or args.highpass is not None:
|
251
|
+
if args.lowpass is not None or args.highpass is not None:
|
244
252
|
lowpass, highpass = args.lowpass, args.highpass
|
245
253
|
if args.pass_format == "voxel":
|
246
254
|
if lowpass is not None:
|
@@ -255,49 +263,58 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
255
263
|
|
256
264
|
try:
|
257
265
|
if args.lowpass >= args.highpass:
|
258
|
-
|
266
|
+
raise ValueError("--lowpass should be smaller than --highpass.")
|
259
267
|
except Exception:
|
260
268
|
pass
|
261
269
|
|
262
|
-
|
270
|
+
bp_cl = BandPassReconstructed if not args.match_projection else BandPass
|
271
|
+
bandpass = bp_cl(
|
263
272
|
use_gaussian=args.no_pass_smooth,
|
264
273
|
lowpass=lowpass,
|
265
274
|
highpass=highpass,
|
266
275
|
sampling_rate=template.sampling_rate,
|
267
276
|
)
|
277
|
+
bandpass.opening_axis, bandpass.tilt_axis = args.wedge_axes
|
268
278
|
template_filter.append(bandpass)
|
269
279
|
target_filter.append(bandpass)
|
270
280
|
|
281
|
+
if not args.match_projection:
|
282
|
+
rec_filt = (Wedge, CTF)
|
283
|
+
needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
|
284
|
+
if needs_reconstruction > 0 and args.reconstruction_filter is None:
|
285
|
+
warnings.warn(
|
286
|
+
"Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
|
287
|
+
"to avoid artifacts from reconstruction using weighted backprojection."
|
288
|
+
)
|
289
|
+
|
290
|
+
template_filter = sorted(
|
291
|
+
template_filter, key=lambda x: type(x) in rec_filt, reverse=True
|
292
|
+
)
|
293
|
+
if needs_reconstruction > 0:
|
294
|
+
relevant_filters = [x for x in template_filter if type(x) in rec_filt]
|
295
|
+
if len(relevant_filters) == 0:
|
296
|
+
raise ValueError("Filters require ")
|
297
|
+
|
298
|
+
reconstruction_filter = ReconstructFromTilt(
|
299
|
+
reconstruction_filter=args.reconstruction_filter,
|
300
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
301
|
+
angles=relevant_filters[0].angles,
|
302
|
+
opening_axis=args.wedge_axes[0],
|
303
|
+
tilt_axis=args.wedge_axes[1],
|
304
|
+
)
|
305
|
+
template_filter.insert(needs_reconstruction, reconstruction_filter)
|
306
|
+
else:
|
307
|
+
template_filter.append(ShiftFourier())
|
308
|
+
if len(target_filter):
|
309
|
+
target_filter.append(ShiftFourier())
|
310
|
+
|
311
|
+
# LinearWhiteningFilter does not support working on tilts yet, hence we
|
312
|
+
# can safely evaluate it after all other filters
|
271
313
|
if args.whiten_spectrum:
|
272
314
|
whitening_filter = LinearWhiteningFilter()
|
273
315
|
template_filter.append(whitening_filter)
|
274
316
|
target_filter.append(whitening_filter)
|
275
317
|
|
276
|
-
rec_filt = (Wedge, CTF)
|
277
|
-
needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
|
278
|
-
if needs_reconstruction > 0 and args.reconstruction_filter is None:
|
279
|
-
warnings.warn(
|
280
|
-
"Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
|
281
|
-
"to avoid artifacts from reconstruction using weighted backprojection."
|
282
|
-
)
|
283
|
-
|
284
|
-
template_filter = sorted(
|
285
|
-
template_filter, key=lambda x: type(x) in rec_filt, reverse=True
|
286
|
-
)
|
287
|
-
if needs_reconstruction > 0:
|
288
|
-
relevant_filters = [x for x in template_filter if type(x) in rec_filt]
|
289
|
-
if len(relevant_filters) == 0:
|
290
|
-
raise ValueError("Filters require ")
|
291
|
-
|
292
|
-
reconstruction_filter = ReconstructFromTilt(
|
293
|
-
reconstruction_filter=args.reconstruction_filter,
|
294
|
-
interpolation_order=args.reconstruction_interpolation_order,
|
295
|
-
angles=relevant_filters[0].angles,
|
296
|
-
opening_axis=args.wedge_axes[0],
|
297
|
-
tilt_axis=args.wedge_axes[1],
|
298
|
-
)
|
299
|
-
template_filter.insert(needs_reconstruction, reconstruction_filter)
|
300
|
-
|
301
318
|
template_filter = Compose(template_filter) if len(template_filter) else None
|
302
319
|
target_filter = Compose(target_filter) if len(target_filter) else None
|
303
320
|
if args.no_filter_target:
|
@@ -359,8 +376,7 @@ def parse_args():
|
|
359
376
|
"--invert-target-contrast",
|
360
377
|
action="store_true",
|
361
378
|
default=False,
|
362
|
-
help="Invert
|
363
|
-
"template has not been inverted.",
|
379
|
+
help="Invert contrast by multiplication with negative one.",
|
364
380
|
)
|
365
381
|
io_group.add_argument(
|
366
382
|
"--scramble-phases",
|
@@ -411,6 +427,13 @@ def parse_args():
|
|
411
427
|
choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
|
412
428
|
help="Template matching scoring function.",
|
413
429
|
)
|
430
|
+
scoring_group.add_argument(
|
431
|
+
"--background-correction",
|
432
|
+
choices=["phase-scrambling"],
|
433
|
+
required=False,
|
434
|
+
help="Transform cross-correlation into SNR-like values using a given method: "
|
435
|
+
"'phase-scrambling' uses a phase-scrambled template as background",
|
436
|
+
)
|
414
437
|
|
415
438
|
angular_group = parser.add_argument_group("Angular Sampling")
|
416
439
|
angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
|
@@ -493,9 +516,8 @@ def parse_args():
|
|
493
516
|
computation_group.add_argument(
|
494
517
|
"--gpu-indices",
|
495
518
|
type=str,
|
496
|
-
default=
|
497
|
-
help="
|
498
|
-
"CUDA_VISIBLE_DEVICES will be used.",
|
519
|
+
default=os.environ.get("CUDA_VISIBLE_DEVICES"),
|
520
|
+
help="GPU indices, e.g., '0,1,2', defaults to CUDA_VISIBLE_DEVICES.",
|
499
521
|
)
|
500
522
|
computation_group.add_argument(
|
501
523
|
"--memory",
|
@@ -527,15 +549,13 @@ def parse_args():
|
|
527
549
|
"--lowpass",
|
528
550
|
type=float,
|
529
551
|
required=False,
|
530
|
-
help="Resolution to lowpass filter template and target to
|
531
|
-
"as the sampling rate of template and target (typically Ångstrom).",
|
552
|
+
help="Resolution to lowpass filter template and target to.",
|
532
553
|
)
|
533
554
|
filter_group.add_argument(
|
534
555
|
"--highpass",
|
535
556
|
type=float,
|
536
557
|
required=False,
|
537
|
-
help="Resolution to highpass filter template and target to
|
538
|
-
"as the sampling rate of template and target (typically Ångstrom).",
|
558
|
+
help="Resolution to highpass filter template and target to.",
|
539
559
|
)
|
540
560
|
filter_group.add_argument(
|
541
561
|
"--no-pass-smooth",
|
@@ -549,14 +569,13 @@ def parse_args():
|
|
549
569
|
required=False,
|
550
570
|
default="sampling_rate",
|
551
571
|
choices=["sampling_rate", "voxel", "frequency"],
|
552
|
-
help="How values passed to --lowpass and --highpass should be interpreted. "
|
553
|
-
"Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
|
572
|
+
help="How values passed to --lowpass and --highpass should be interpreted. ",
|
554
573
|
)
|
555
574
|
filter_group.add_argument(
|
556
575
|
"--whiten-spectrum",
|
557
576
|
action="store_true",
|
558
577
|
default=None,
|
559
|
-
help="Apply spectral whitening to template and target
|
578
|
+
help="Apply spectral whitening to template and target.",
|
560
579
|
)
|
561
580
|
filter_group.add_argument(
|
562
581
|
"--wedge-axes",
|
@@ -603,6 +622,7 @@ def parse_args():
|
|
603
622
|
type=int,
|
604
623
|
default=1,
|
605
624
|
required=False,
|
625
|
+
choices=[0, 1, 2, 3, 4, 5],
|
606
626
|
help="Analogous to --interpolation-order but for reconstruction.",
|
607
627
|
)
|
608
628
|
filter_group.add_argument(
|
@@ -664,40 +684,32 @@ def parse_args():
|
|
664
684
|
required=False,
|
665
685
|
help="Do not perform phase-flipping CTF correction.",
|
666
686
|
)
|
667
|
-
ctf_group.add_argument(
|
668
|
-
"--correct-defocus-gradient",
|
669
|
-
action="store_true",
|
670
|
-
required=False,
|
671
|
-
help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
|
672
|
-
"defocus gradients.",
|
673
|
-
)
|
674
687
|
|
675
688
|
performance_group = parser.add_argument_group("Performance")
|
676
689
|
performance_group.add_argument(
|
677
690
|
"--centering",
|
678
691
|
action="store_true",
|
679
|
-
help="
|
692
|
+
help="Translate the template's center of mass to the center of the box.",
|
680
693
|
)
|
681
694
|
performance_group.add_argument(
|
682
695
|
"--pad-edges",
|
683
696
|
action="store_true",
|
684
697
|
default=False,
|
685
|
-
help="
|
686
|
-
"activated automatically if splitting is required to avoid boundary artifacts.",
|
698
|
+
help="Zero pad the target. Defaults to True if splitting is required..",
|
687
699
|
)
|
688
700
|
performance_group.add_argument(
|
689
701
|
"--interpolation-order",
|
690
702
|
required=False,
|
691
703
|
type=int,
|
692
704
|
default=None,
|
693
|
-
|
694
|
-
"and pytorch
|
705
|
+
choices=[0, 1, 2, 3, 4, 5],
|
706
|
+
help="Spline order for rotation, default is 3 and 1 for jax and pytorch.",
|
695
707
|
)
|
696
708
|
performance_group.add_argument(
|
697
709
|
"--use-memmap",
|
698
710
|
action="store_true",
|
699
711
|
default=False,
|
700
|
-
help="Memmap
|
712
|
+
help="Memmap analyzer data, useful for matching on very large inputs.",
|
701
713
|
)
|
702
714
|
|
703
715
|
analyzer_group = parser.add_argument_group("Output / Analysis")
|
@@ -706,7 +718,7 @@ def parse_args():
|
|
706
718
|
required=False,
|
707
719
|
type=float,
|
708
720
|
default=0,
|
709
|
-
help="Minimum template matching scores to consider
|
721
|
+
help="Minimum template matching scores to consider.",
|
710
722
|
)
|
711
723
|
analyzer_group.add_argument(
|
712
724
|
"-p",
|
@@ -724,30 +736,20 @@ def parse_args():
|
|
724
736
|
args = parser.parse_args()
|
725
737
|
args.version = __version__
|
726
738
|
|
727
|
-
if args.
|
728
|
-
|
729
|
-
if args.backend in ("jax", "pytorch"):
|
730
|
-
args.interpolation_order = 1
|
731
|
-
args.reconstruction_interpolation_order = 1
|
732
|
-
|
733
|
-
if args.interpolation_order < 0:
|
734
|
-
args.interpolation_order = None
|
735
|
-
|
736
|
-
if args.temp_directory is None:
|
737
|
-
args.temp_directory = gettempdir()
|
738
|
-
|
739
|
-
os.environ["TMPDIR"] = args.temp_directory
|
740
|
-
if args.gpu_indices is not None:
|
741
|
-
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
|
739
|
+
if args.temp_directory is not None:
|
740
|
+
os.environ["TMPDIR"] = args.temp_directory
|
742
741
|
|
743
|
-
|
742
|
+
# Tilt angles can be specified as range or using a suitable input file
|
743
|
+
is_file = exists(args.tilt_angles) if args.tilt_angles is not None else False
|
744
|
+
if args.tilt_angles is not None and not is_file:
|
744
745
|
try:
|
745
|
-
float(args.tilt_angles.split(",")
|
746
|
+
args.tilt_angles = tuple(abs(float(x)) for x in args.tilt_angles.split(","))
|
746
747
|
except Exception:
|
747
748
|
raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
|
748
749
|
|
750
|
+
# Since both Wedge.from_file and CTF.from_file parse similar inputs, we can
|
751
|
+
# fall back to assigning the ctf_file to args.tilt_angles
|
749
752
|
if args.ctf_file is not None and args.tilt_angles is None:
|
750
|
-
# Check if tilt angles can be extracted from CTF specification
|
751
753
|
try:
|
752
754
|
ctf = CTF.from_file(args.ctf_file)
|
753
755
|
if ctf.angles is None:
|
@@ -758,7 +760,14 @@ def parse_args():
|
|
758
760
|
"Need to specify --tilt-angles when not provided in --ctf-file."
|
759
761
|
)
|
760
762
|
|
761
|
-
|
763
|
+
# For projection matching we cannot use continuous wedge masks
|
764
|
+
args.match_projection = False
|
765
|
+
if not is_file and args.match_projection:
|
766
|
+
raise ValueError(
|
767
|
+
"Projection angles are required via --tilt-angles or --ctf-file."
|
768
|
+
)
|
769
|
+
|
770
|
+
# Handle constrained matching inputs
|
762
771
|
if args.orientations is not None:
|
763
772
|
orientations = Orientations.from_file(args.orientations)
|
764
773
|
orientations.translations = np.divide(
|
@@ -766,6 +775,64 @@ def parse_args():
|
|
766
775
|
)
|
767
776
|
args.orientations = orientations
|
768
777
|
|
778
|
+
if args.orientations_uncertainty is not None:
|
779
|
+
args.orientations_uncertainty = tuple(
|
780
|
+
int(x) for x in args.orientations_uncertainty.split(",")
|
781
|
+
)
|
782
|
+
|
783
|
+
# Handle backend specificities
|
784
|
+
if args.interpolation_order is None:
|
785
|
+
args.interpolation_order = 3
|
786
|
+
if args.backend in ("jax", "pytorch"):
|
787
|
+
args.interpolation_order = 1
|
788
|
+
args.reconstruction_interpolation_order = 1
|
789
|
+
|
790
|
+
# This flag is not passed to backend yet, but might aswell be verbose about it
|
791
|
+
if args.interpolation_order != 1 and args.backend == "jax":
|
792
|
+
warnings.warn("Setting interpolation order to order jax supports (1).")
|
793
|
+
args.interpolation_order = 1
|
794
|
+
args.reconstruction_interpolation_order = 1
|
795
|
+
|
796
|
+
if args.interpolation_order == 3 and args.backend == "pytorch":
|
797
|
+
warnings.warn("Pytorch does not support order 3, changing it to 1.")
|
798
|
+
args.interpolation_order = 1
|
799
|
+
if args.reconstruction_interpolation_order == 3:
|
800
|
+
args.reconstruction_interpolation_order = args.interpolation_order
|
801
|
+
|
802
|
+
# Handle GPU device specification for suitable backends
|
803
|
+
if args.backend in ("pytorch", "cupy", "jax"):
|
804
|
+
if args.gpu_indices is None:
|
805
|
+
warnings.warn(
|
806
|
+
"No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
|
807
|
+
"Assuming device 0.",
|
808
|
+
)
|
809
|
+
args.gpu_indices = "0"
|
810
|
+
|
811
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
|
812
|
+
args.gpu_indices = [int(x) for x in args.gpu_indices.split(",")]
|
813
|
+
args.cores = len(args.gpu_indices)
|
814
|
+
|
815
|
+
if args.backend == "jax" and args.peak_calling:
|
816
|
+
raise ValueError("Jax supports only subclasses of MaxScoreOverRotations.")
|
817
|
+
|
818
|
+
# Wedge axes do not have meaning for projections
|
819
|
+
args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
|
820
|
+
if args.match_projection:
|
821
|
+
args.wedge_axes = None, None
|
822
|
+
|
823
|
+
if args.match_projection and args.backend != "jax":
|
824
|
+
raise ValueError("Projection matching is only supported for --backend jax.")
|
825
|
+
|
826
|
+
# This is implicitly caught in the jax check above, but keeping it for future use
|
827
|
+
if args.match_projection and args.peak_calling:
|
828
|
+
raise ValueError("Peak calling is not yet supported for projection matching.")
|
829
|
+
|
830
|
+
if args.orientations is not None and args.peak_calling:
|
831
|
+
raise ValueError(
|
832
|
+
"Peak calling and constrained matching simultaneously is not yet supported."
|
833
|
+
)
|
834
|
+
|
835
|
+
# Avoid relative input specification
|
769
836
|
args.target = abspath(args.target)
|
770
837
|
if args.target_mask is not None:
|
771
838
|
args.target_mask = abspath(args.target_mask)
|
@@ -782,7 +849,6 @@ def main():
|
|
782
849
|
print_entry()
|
783
850
|
|
784
851
|
target = Density.from_file(args.target, use_memmap=True)
|
785
|
-
|
786
852
|
try:
|
787
853
|
template = Density.from_file(args.template)
|
788
854
|
except Exception:
|
@@ -800,27 +866,24 @@ def main():
|
|
800
866
|
)
|
801
867
|
|
802
868
|
if target.sampling_rate.size == template.sampling_rate.size:
|
803
|
-
|
869
|
+
sampling_rate_match = np.allclose(
|
804
870
|
np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
|
805
|
-
)
|
871
|
+
)
|
872
|
+
# For projection we omit the warning as the leading dimension has no sampling
|
873
|
+
if not sampling_rate_match and not args.match_projection:
|
806
874
|
warnings.warn(
|
807
875
|
f"Sampling rate mismatch detected: target={target.sampling_rate} "
|
808
876
|
f"template={template.sampling_rate}. Proceeding with user-provided "
|
809
877
|
f"values. Make sure this is intentional. "
|
810
878
|
)
|
811
879
|
|
812
|
-
template_mask = load_and_validate_mask(
|
813
|
-
|
814
|
-
)
|
815
|
-
target_mask = load_and_validate_mask(
|
816
|
-
mask_target=target, mask_path=args.target_mask, use_memmap=True
|
817
|
-
)
|
880
|
+
template_mask = load_and_validate_mask(template, args.template_mask)
|
881
|
+
target_mask = load_and_validate_mask(target, args.target_mask, use_memmap=True)
|
818
882
|
|
819
|
-
initial_shape = target.shape
|
820
883
|
print_block(
|
821
884
|
name="Target",
|
822
885
|
data={
|
823
|
-
"Initial Shape":
|
886
|
+
"Initial Shape": target.shape,
|
824
887
|
"Sampling Rate": _format_sampling(target.sampling_rate),
|
825
888
|
"Final Shape": target.shape,
|
826
889
|
},
|
@@ -830,16 +893,15 @@ def main():
|
|
830
893
|
print_block(
|
831
894
|
name="Target Mask",
|
832
895
|
data={
|
833
|
-
"Initial Shape":
|
896
|
+
"Initial Shape": target_mask.shape,
|
834
897
|
"Sampling Rate": _format_sampling(target_mask.sampling_rate),
|
835
898
|
"Final Shape": target_mask.shape,
|
836
899
|
},
|
837
900
|
)
|
838
901
|
|
839
902
|
initial_shape = template.shape
|
840
|
-
translation = np.zeros(len(template.shape), dtype=np.float32)
|
841
903
|
if args.centering:
|
842
|
-
template
|
904
|
+
template = template.centered(0)
|
843
905
|
|
844
906
|
print_block(
|
845
907
|
name="Template",
|
@@ -852,27 +914,12 @@ def main():
|
|
852
914
|
|
853
915
|
if template_mask is None:
|
854
916
|
template_mask = template.empty
|
855
|
-
if not args.centering:
|
856
|
-
enclosing_box = template.minimum_enclosing_box(
|
857
|
-
0, use_geometric_center=False
|
858
|
-
)
|
859
|
-
template_mask.adjust_box(enclosing_box)
|
860
|
-
|
861
|
-
template_mask.data[:] = 1
|
862
|
-
translation = np.zeros_like(translation)
|
863
917
|
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
translation = np.add(translation, origin_translation)
|
918
|
+
# Pre 0.3.2 we used to perform a rigid transform on the template mask to match
|
919
|
+
# the template origin, but this seems overly pedantic given the sporadic use
|
920
|
+
# of the origin parameter in the matching pipeline
|
921
|
+
template_mask.data = np.ones(template.shape, dtype=template.data.dtype)
|
869
922
|
|
870
|
-
template_mask = template_mask.rigid_transform(
|
871
|
-
rotation_matrix=np.eye(template_mask.data.ndim),
|
872
|
-
translation=-translation,
|
873
|
-
order=1,
|
874
|
-
)
|
875
|
-
template_mask.origin = template.origin.copy()
|
876
923
|
print_block(
|
877
924
|
name="Template Mask",
|
878
925
|
data={
|
@@ -883,71 +930,35 @@ def main():
|
|
883
930
|
)
|
884
931
|
print("\n" + "-" * 80)
|
885
932
|
|
886
|
-
if args.scramble_phases:
|
887
|
-
template.data = scramble_phases(template.data, noise_proportion=1.0)
|
888
|
-
|
889
933
|
callback_class = MaxScoreOverRotations
|
890
934
|
if args.orientations is not None:
|
891
935
|
callback_class = MaxScoreOverRotationsConstrained
|
892
936
|
elif args.peak_calling:
|
893
937
|
callback_class = PeakCallerMaximumFilter
|
894
938
|
|
895
|
-
#
|
896
|
-
|
897
|
-
|
898
|
-
raise ValueError("Requested backend is not available.")
|
899
|
-
if args.backend == "jax" and callback_class != MaxScoreOverRotations:
|
900
|
-
raise ValueError(
|
901
|
-
"Jax backend only supports the MaxScoreOverRotations analyzer."
|
902
|
-
)
|
903
|
-
|
904
|
-
if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
|
905
|
-
warnings.warn(
|
906
|
-
"Jax and pytorch do not support interpolation order 3, setting it to 1."
|
907
|
-
)
|
908
|
-
args.interpolation_order = 1
|
909
|
-
|
910
|
-
if args.backend in ("pytorch", "cupy", "jax"):
|
911
|
-
gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
912
|
-
if gpu_devices is None:
|
913
|
-
warnings.warn(
|
914
|
-
"No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
|
915
|
-
"Assuming device 0.",
|
916
|
-
)
|
917
|
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
918
|
-
|
919
|
-
args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
|
920
|
-
args.gpu_indices = [
|
921
|
-
int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
922
|
-
]
|
923
|
-
|
924
|
-
# Finally set the desired backend
|
925
|
-
device = "cuda"
|
926
|
-
args.use_gpu = False
|
927
|
-
be.change_backend(args.backend)
|
939
|
+
# We currently do not allow parallelizing angular searches in the GPU compatible
|
940
|
+
# backends, so we keep this flag to compute a suitable splitting schedule
|
941
|
+
use_gpu = False
|
928
942
|
if args.backend in ("jax", "pytorch", "cupy"):
|
929
|
-
|
943
|
+
use_gpu = True
|
930
944
|
|
945
|
+
# Finally set the requested backend
|
946
|
+
be.change_backend(args.backend)
|
931
947
|
if args.backend == "pytorch":
|
932
948
|
try:
|
933
|
-
be.change_backend("pytorch", device=
|
949
|
+
be.change_backend("pytorch", device="cuda")
|
934
950
|
# Trigger exception if not compiled with device
|
935
951
|
be.get_available_memory()
|
936
952
|
except Exception as e:
|
953
|
+
# Let the user know they did not compile with GPU devices
|
937
954
|
print(e)
|
938
|
-
|
939
|
-
|
940
|
-
be.change_backend("pytorch", device=device)
|
955
|
+
use_gpu = False
|
956
|
+
be.change_backend("pytorch", device="cpu")
|
941
957
|
|
942
958
|
available_memory = be.get_available_memory() * be.device_count()
|
943
959
|
if args.memory is None:
|
944
960
|
args.memory = int(args.memory_scaling * available_memory)
|
945
961
|
|
946
|
-
if args.orientations_uncertainty is not None:
|
947
|
-
args.orientations_uncertainty = tuple(
|
948
|
-
int(x) for x in args.orientations_uncertainty.split(",")
|
949
|
-
)
|
950
|
-
|
951
962
|
matching_data = MatchingData(
|
952
963
|
target=target,
|
953
964
|
template=template.data,
|
@@ -956,19 +967,23 @@ def main():
|
|
956
967
|
invert_target=args.invert_target_contrast,
|
957
968
|
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
958
969
|
)
|
959
|
-
|
960
|
-
|
961
|
-
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
962
|
-
args, template, target
|
963
|
-
)
|
970
|
+
if args.scramble_phases:
|
971
|
+
matching_data.template = matching_data.transform_template("phase_randomization")
|
964
972
|
|
965
973
|
matching_data.set_matching_dimension(
|
966
974
|
target_dim=target.metadata.get("batch_dimension", None),
|
967
975
|
template_dim=template.metadata.get("batch_dimension", None),
|
968
976
|
)
|
977
|
+
if args.match_projection:
|
978
|
+
matching_data.set_matching_dimension(target_dim=0)
|
979
|
+
|
969
980
|
args.batch_dims = tuple(int(x) for x in np.where(matching_data._batch_mask)[0])
|
981
|
+
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
982
|
+
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
983
|
+
args, template, target
|
984
|
+
)
|
970
985
|
|
971
|
-
splits, schedule = compute_schedule(args, matching_data, callback_class)
|
986
|
+
splits, schedule = compute_schedule(args, matching_data, callback_class, use_gpu)
|
972
987
|
|
973
988
|
n_splits = np.prod(list(splits.values()))
|
974
989
|
target_split = ", ".join(
|
@@ -980,6 +995,7 @@ def main():
|
|
980
995
|
f" [{matching_data.rotations.shape[0]} rotations]",
|
981
996
|
"Center Template": args.centering,
|
982
997
|
"Scramble Template": args.scramble_phases,
|
998
|
+
"Background Correction": args.background_correction,
|
983
999
|
"Invert Contrast": args.invert_target_contrast,
|
984
1000
|
"Extend Target Edges": args.pad_edges,
|
985
1001
|
"Interpolation Order": args.interpolation_order,
|
@@ -1020,7 +1036,6 @@ def main():
|
|
1020
1036
|
if args.ctf_file is not None or args.defocus is not None:
|
1021
1037
|
filter_args["CTF File"] = args.ctf_file
|
1022
1038
|
filter_args["Flip Phase"] = args.no_flip_phase
|
1023
|
-
filter_args["Correct Defocus"] = args.correct_defocus_gradient
|
1024
1039
|
|
1025
1040
|
filter_args = {k: v for k, v in filter_args.items() if v is not None}
|
1026
1041
|
if len(filter_args):
|
@@ -1042,7 +1057,7 @@ def main():
|
|
1042
1057
|
analyzer_args["acceptance_radius"] = args.orientations_uncertainty
|
1043
1058
|
analyzer_args["positions"] = args.orientations.translations
|
1044
1059
|
analyzer_args["rotations"] = euler_to_rotationmatrix(
|
1045
|
-
args.orientations.rotations
|
1060
|
+
args.orientations.rotations, seq="ZYZ"
|
1046
1061
|
)
|
1047
1062
|
|
1048
1063
|
print_block(
|
@@ -1062,7 +1077,7 @@ def main():
|
|
1062
1077
|
|
1063
1078
|
start = time()
|
1064
1079
|
print("Running Template Matching. This might take a while ...")
|
1065
|
-
candidates =
|
1080
|
+
candidates = match_exhaustive(
|
1066
1081
|
matching_data=matching_data,
|
1067
1082
|
job_schedule=schedule,
|
1068
1083
|
matching_score=matching_score,
|
@@ -1072,6 +1087,8 @@ def main():
|
|
1072
1087
|
target_splits=splits,
|
1073
1088
|
pad_target_edges=args.pad_edges,
|
1074
1089
|
interpolation_order=args.interpolation_order,
|
1090
|
+
match_projection=args.match_projection,
|
1091
|
+
background_correction=args.background_correction,
|
1075
1092
|
)
|
1076
1093
|
|
1077
1094
|
candidates = list(candidates) if candidates is not None else []
|