pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/match_template.py +28 -39
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocessor_gui.py +95 -24
- pytme-0.3.1.post1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/RECORD +55 -48
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +28 -39
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +95 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +156 -386
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +26 -13
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +76 -37
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +2 -6
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +13 -8
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/top_level.txt +0 -0
@@ -12,8 +12,8 @@ from sys import exit
|
|
12
12
|
from time import time
|
13
13
|
from typing import Tuple
|
14
14
|
from copy import deepcopy
|
15
|
-
from os.path import exists
|
16
15
|
from tempfile import gettempdir
|
16
|
+
from os.path import exists, abspath
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
|
@@ -576,7 +576,7 @@ def parse_args():
|
|
576
576
|
"'angles', or a single column file without header. Exposure will be taken from "
|
577
577
|
"the input file , if you are using a tab-separated file, the column names "
|
578
578
|
"'angles' and 'weights' need to be present. It is also possible to specify a "
|
579
|
-
"continuous wedge mask using e.g.,
|
579
|
+
"continuous wedge mask using e.g., 50,45.",
|
580
580
|
)
|
581
581
|
filter_group.add_argument(
|
582
582
|
"--tilt-weighting",
|
@@ -685,13 +685,6 @@ def parse_args():
|
|
685
685
|
help="Useful if the target does not have a well-defined bounding box. Will be "
|
686
686
|
"activated automatically if splitting is required to avoid boundary artifacts.",
|
687
687
|
)
|
688
|
-
performance_group.add_argument(
|
689
|
-
"--pad-filter",
|
690
|
-
action="store_true",
|
691
|
-
default=False,
|
692
|
-
help="Pad the template filter to the shape of the target. Useful for fast "
|
693
|
-
"oscilating filters to avoid aliasing effects.",
|
694
|
-
)
|
695
688
|
performance_group.add_argument(
|
696
689
|
"--interpolation-order",
|
697
690
|
required=False,
|
@@ -700,13 +693,6 @@ def parse_args():
|
|
700
693
|
help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
|
701
694
|
"and pytorch backends.",
|
702
695
|
)
|
703
|
-
performance_group.add_argument(
|
704
|
-
"--use-mixed-precision",
|
705
|
-
action="store_true",
|
706
|
-
default=False,
|
707
|
-
help="Use float16 for real values operations where possible. Not supported "
|
708
|
-
"for jax backend.",
|
709
|
-
)
|
710
696
|
performance_group.add_argument(
|
711
697
|
"--use-memmap",
|
712
698
|
action="store_true",
|
@@ -742,6 +728,7 @@ def parse_args():
|
|
742
728
|
args.interpolation_order = 3
|
743
729
|
if args.backend in ("jax", "pytorch"):
|
744
730
|
args.interpolation_order = 1
|
731
|
+
args.reconstruction_interpolation_order = 1
|
745
732
|
|
746
733
|
if args.interpolation_order < 0:
|
747
734
|
args.interpolation_order = None
|
@@ -779,6 +766,14 @@ def parse_args():
|
|
779
766
|
)
|
780
767
|
args.orientations = orientations
|
781
768
|
|
769
|
+
args.target = abspath(args.target)
|
770
|
+
if args.target_mask is not None:
|
771
|
+
args.target_mask = abspath(args.target_mask)
|
772
|
+
|
773
|
+
args.template = abspath(args.template)
|
774
|
+
if args.template_mask is not None:
|
775
|
+
args.template_mask = abspath(args.template_mask)
|
776
|
+
|
782
777
|
return args
|
783
778
|
|
784
779
|
|
@@ -796,15 +791,23 @@ def main():
|
|
796
791
|
sampling_rate=target.sampling_rate,
|
797
792
|
)
|
798
793
|
|
794
|
+
if np.allclose(target.sampling_rate, 1):
|
795
|
+
warnings.warn(
|
796
|
+
"Target sampling rate is 1.0, which may indicate missing or incorrect "
|
797
|
+
"metadata. Verify that your target file contains proper sampling rate "
|
798
|
+
"information, as filters (CTF, BandPass) require accurate sampling rates "
|
799
|
+
"to function correctly."
|
800
|
+
)
|
801
|
+
|
799
802
|
if target.sampling_rate.size == template.sampling_rate.size:
|
800
803
|
if not np.allclose(
|
801
804
|
np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
|
802
805
|
):
|
803
|
-
|
804
|
-
f"
|
805
|
-
"
|
806
|
+
warnings.warn(
|
807
|
+
f"Sampling rate mismatch detected: target={target.sampling_rate} "
|
808
|
+
f"template={template.sampling_rate}. Proceeding with user-provided "
|
809
|
+
f"values. Make sure this is intentional. "
|
806
810
|
)
|
807
|
-
template = template.resample(target.sampling_rate, order=3)
|
808
811
|
|
809
812
|
template_mask = load_and_validate_mask(
|
810
813
|
mask_target=template, mask_path=args.template_mask
|
@@ -881,16 +884,13 @@ def main():
|
|
881
884
|
print("\n" + "-" * 80)
|
882
885
|
|
883
886
|
if args.scramble_phases:
|
884
|
-
template.data = scramble_phases(
|
885
|
-
template.data, noise_proportion=1.0, normalize_power=False
|
886
|
-
)
|
887
|
+
template.data = scramble_phases(template.data, noise_proportion=1.0)
|
887
888
|
|
888
889
|
callback_class = MaxScoreOverRotations
|
889
|
-
if args.peak_calling:
|
890
|
-
callback_class = PeakCallerMaximumFilter
|
891
|
-
|
892
890
|
if args.orientations is not None:
|
893
891
|
callback_class = MaxScoreOverRotationsConstrained
|
892
|
+
elif args.peak_calling:
|
893
|
+
callback_class = PeakCallerMaximumFilter
|
894
894
|
|
895
895
|
# Determine suitable backend for the selected operation
|
896
896
|
available_backends = be.available_backends()
|
@@ -939,16 +939,6 @@ def main():
|
|
939
939
|
args.use_gpu = False
|
940
940
|
be.change_backend("pytorch", device=device)
|
941
941
|
|
942
|
-
# TODO: Make the inverse casting from complex64 -> float 16 stable
|
943
|
-
# if args.use_mixed_precision:
|
944
|
-
# be.change_backend(
|
945
|
-
# backend_name=args.backend,
|
946
|
-
# float_dtype=be._array_backend.float16,
|
947
|
-
# complex_dtype=be._array_backend.complex64,
|
948
|
-
# int_dtype=be._array_backend.int16,
|
949
|
-
# device=device,
|
950
|
-
# )
|
951
|
-
|
952
942
|
available_memory = be.get_available_memory() * be.device_count()
|
953
943
|
if args.memory is None:
|
954
944
|
args.memory = int(args.memory_scaling * available_memory)
|
@@ -976,6 +966,8 @@ def main():
|
|
976
966
|
target_dim=target.metadata.get("batch_dimension", None),
|
977
967
|
template_dim=template.metadata.get("batch_dimension", None),
|
978
968
|
)
|
969
|
+
args.batch_dims = tuple(int(x) for x in np.where(matching_data._batch_mask)[0])
|
970
|
+
|
979
971
|
splits, schedule = compute_schedule(args, matching_data, callback_class)
|
980
972
|
|
981
973
|
n_splits = np.prod(list(splits.values()))
|
@@ -1004,7 +996,6 @@ def main():
|
|
1004
996
|
compute_options = {
|
1005
997
|
"Backend": be._BACKEND_REGISTRY[be._backend_name],
|
1006
998
|
"Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
|
1007
|
-
"Use Mixed Precision": args.use_mixed_precision,
|
1008
999
|
"Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
|
1009
1000
|
"Temporary Directory": args.temp_directory,
|
1010
1001
|
"Target Splits": f"{target_split} [N={n_splits}]",
|
@@ -1025,7 +1016,6 @@ def main():
|
|
1025
1016
|
"Tilt Angles": args.tilt_angles,
|
1026
1017
|
"Tilt Weighting": args.tilt_weighting,
|
1027
1018
|
"Reconstruction Filter": args.reconstruction_filter,
|
1028
|
-
"Extend Filter Grid": args.pad_filter,
|
1029
1019
|
}
|
1030
1020
|
if args.ctf_file is not None or args.defocus is not None:
|
1031
1021
|
filter_args["CTF File"] = args.ctf_file
|
@@ -1081,7 +1071,6 @@ def main():
|
|
1081
1071
|
callback_class_args=analyzer_args,
|
1082
1072
|
target_splits=splits,
|
1083
1073
|
pad_target_edges=args.pad_edges,
|
1084
|
-
pad_template_filter=args.pad_filter,
|
1085
1074
|
interpolation_order=args.interpolation_order,
|
1086
1075
|
)
|
1087
1076
|
|
@@ -17,7 +17,7 @@ from scipy.special import erfcinv
|
|
17
17
|
|
18
18
|
from tme import Density, Structure, Orientations
|
19
19
|
from tme.cli import sanitize_name, print_block, print_entry
|
20
|
-
from tme.matching_utils import load_pickle,
|
20
|
+
from tme.matching_utils import load_pickle, center_slice, write_pickle
|
21
21
|
from tme.matching_optimization import create_score_object, optimize_match
|
22
22
|
from tme.rotations import euler_to_rotationmatrix, euler_from_rotationmatrix
|
23
23
|
from tme.analyzer import (
|
@@ -87,6 +87,11 @@ def parse_args():
|
|
87
87
|
help="Output prefix. Defaults to basename of first input. Extension is "
|
88
88
|
"added with respect to chosen output format.",
|
89
89
|
)
|
90
|
+
output_group.add_argument(
|
91
|
+
"--angles-clockwise",
|
92
|
+
action="store_true",
|
93
|
+
help="Report Euler angles in clockwise format expected by RELION.",
|
94
|
+
)
|
90
95
|
output_group.add_argument(
|
91
96
|
"--output-format",
|
92
97
|
choices=[
|
@@ -112,7 +117,7 @@ def parse_args():
|
|
112
117
|
peak_group.add_argument(
|
113
118
|
"--peak-caller",
|
114
119
|
choices=list(PEAK_CALLERS.keys()),
|
115
|
-
default="
|
120
|
+
default="PeakCallerMaximumFilter",
|
116
121
|
help="Peak caller for local maxima identification.",
|
117
122
|
)
|
118
123
|
peak_group.add_argument(
|
@@ -183,7 +188,7 @@ def parse_args():
|
|
183
188
|
)
|
184
189
|
additional_group.add_argument(
|
185
190
|
"--n-false-positives",
|
186
|
-
type=
|
191
|
+
type=float,
|
187
192
|
default=None,
|
188
193
|
required=False,
|
189
194
|
help="Number of accepted false-positives picks to determine minimum score.",
|
@@ -313,11 +318,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
|
|
313
318
|
data = load_matching_output(foreground)
|
314
319
|
scores, _, rotations, rotation_mapping, *_ = data
|
315
320
|
|
316
|
-
# We could normalize to unit sdev, but that might lead to unexpected
|
317
|
-
# results for flat background distributions
|
318
|
-
scores -= scores.mean()
|
319
321
|
indices = tuple(slice(0, x) for x in scores.shape)
|
320
|
-
|
321
322
|
indices_update = scores > scores_out[indices]
|
322
323
|
scores_out[indices][indices_update] = scores[indices_update]
|
323
324
|
|
@@ -364,9 +365,8 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
|
|
364
365
|
scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
|
365
366
|
for background in backgrounds:
|
366
367
|
data_norm = load_matching_output(background)
|
368
|
+
scores, _, rotations, rotation_mapping, *_ = data_norm
|
367
369
|
|
368
|
-
scores = data_norm[0]
|
369
|
-
scores -= scores.mean()
|
370
370
|
indices = tuple(slice(0, x) for x in scores.shape)
|
371
371
|
indices_update = scores > scores_norm[indices]
|
372
372
|
scores_norm[indices][indices_update] = scores[indices_update]
|
@@ -375,7 +375,9 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
|
|
375
375
|
update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
|
376
376
|
scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
|
377
377
|
scores_out[update] = data[0][update] - scores_norm[update]
|
378
|
-
scores_out[update]
|
378
|
+
scores_out[update] += scores_norm[update].mean()
|
379
|
+
|
380
|
+
# scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
|
379
381
|
scores_out = np.fmax(scores_out, 0, out=scores_out)
|
380
382
|
data[0] = scores_out
|
381
383
|
|
@@ -478,15 +480,21 @@ def main():
|
|
478
480
|
if orientations is None:
|
479
481
|
translations, rotations, scores, details = [], [], [], []
|
480
482
|
|
481
|
-
|
482
|
-
|
483
|
+
var = None
|
484
|
+
# Data processed by normalize_input is guaranteed to have this shape)
|
485
|
+
scores, _, rotation_array, rotation_mapping, *_ = data
|
486
|
+
if len(data) == 6:
|
487
|
+
scores, _, rotation_array, rotation_mapping, var, *_ = data
|
483
488
|
|
484
489
|
cropped_shape = np.subtract(
|
485
490
|
scores.shape, np.multiply(args.min_boundary_distance, 2)
|
486
491
|
).astype(int)
|
487
492
|
|
488
493
|
if args.min_boundary_distance > 0:
|
489
|
-
|
494
|
+
_scores = np.zeros_like(scores)
|
495
|
+
subset = center_slice(scores.shape, cropped_shape)
|
496
|
+
_scores[subset] = scores[subset]
|
497
|
+
scores = _scores
|
490
498
|
|
491
499
|
if args.n_false_positives is not None:
|
492
500
|
# Rickgauer et al. 2017
|
@@ -499,17 +507,20 @@ def main():
|
|
499
507
|
)
|
500
508
|
args.n_false_positives = max(args.n_false_positives, 1)
|
501
509
|
n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
|
510
|
+
std = np.std(scores[cropped_slice])
|
511
|
+
if var is not None:
|
512
|
+
std = np.asarray(np.sqrt(var)).reshape(())
|
513
|
+
|
502
514
|
minimum_score = np.multiply(
|
503
515
|
erfcinv(2 * args.n_false_positives / n_correlations),
|
504
|
-
np.sqrt(2) *
|
516
|
+
np.sqrt(2) * std,
|
505
517
|
)
|
506
|
-
print(f"Determined
|
507
|
-
|
508
|
-
args.min_score = minimum_score
|
518
|
+
print(f"Determined cutoff --min-score {minimum_score}.")
|
519
|
+
args.min_score = max(minimum_score, 0)
|
509
520
|
|
510
521
|
args.batch_dims = None
|
511
|
-
if hasattr(cli_args, "
|
512
|
-
args.batch_dims = cli_args.
|
522
|
+
if hasattr(cli_args, "batch_dims"):
|
523
|
+
args.batch_dims = cli_args.batch_dims
|
513
524
|
|
514
525
|
peak_caller_kwargs = {
|
515
526
|
"shape": scores.shape,
|
@@ -517,8 +528,8 @@ def main():
|
|
517
528
|
"min_distance": args.min_distance,
|
518
529
|
"min_boundary_distance": args.min_boundary_distance,
|
519
530
|
"batch_dims": args.batch_dims,
|
520
|
-
"
|
521
|
-
"
|
531
|
+
"min_score": args.min_score,
|
532
|
+
"max_score": args.max_score,
|
522
533
|
}
|
523
534
|
|
524
535
|
peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
|
@@ -611,6 +622,9 @@ def main():
|
|
611
622
|
orientations.rotations[index] = angles
|
612
623
|
orientations.scores[index] = score * -1
|
613
624
|
|
625
|
+
if args.angles_clockwise:
|
626
|
+
orientations.rotations *= -1
|
627
|
+
|
614
628
|
if args.output_format in ("orientations", "relion4", "relion5"):
|
615
629
|
file_format, extension = "text", "tsv"
|
616
630
|
|
@@ -19,13 +19,13 @@ from magicgui import widgets
|
|
19
19
|
from numpy.typing import NDArray
|
20
20
|
from napari.layers import Image
|
21
21
|
from scipy.fft import next_fast_len
|
22
|
-
from qtpy.QtWidgets import QFileDialog
|
23
22
|
from napari.utils.events import EventedList
|
23
|
+
from qtpy.QtWidgets import QFileDialog, QMessageBox
|
24
24
|
|
25
|
-
from tme.backends import backend
|
25
|
+
from tme.backends import backend as be
|
26
26
|
from tme.rotations import align_vectors
|
27
|
-
from tme import Preprocessor, Density, Orientations
|
28
27
|
from tme.matching_utils import create_mask, load_pickle
|
28
|
+
from tme import Preprocessor, Density, Orientations
|
29
29
|
from tme.filters import BandPassReconstructed, CTFReconstructed
|
30
30
|
|
31
31
|
preprocessor = Preprocessor()
|
@@ -69,7 +69,7 @@ def ctf_filter(
|
|
69
69
|
flip_phase: bool = False,
|
70
70
|
) -> NDArray:
|
71
71
|
fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
|
72
|
-
template_pad =
|
72
|
+
template_pad = be.topleft_pad(template, fast_shape)
|
73
73
|
template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
|
74
74
|
ctf = CTFReconstructed(
|
75
75
|
shape=fast_shape,
|
@@ -85,7 +85,7 @@ def ctf_filter(
|
|
85
85
|
)
|
86
86
|
np.multiply(template_ft, ctf()["data"], out=template_ft)
|
87
87
|
template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
|
88
|
-
template =
|
88
|
+
template = be.topleft_pad(template_pad, template.shape)
|
89
89
|
return template
|
90
90
|
|
91
91
|
|
@@ -392,7 +392,13 @@ class FilterWidget(widgets.Container):
|
|
392
392
|
metadata = selected_layer_metadata.copy()
|
393
393
|
if "filter_parameters" not in metadata:
|
394
394
|
metadata["filter_parameters"] = []
|
395
|
-
|
395
|
+
|
396
|
+
payload = {filter_name: kwargs.copy()}
|
397
|
+
if isinstance(metadata["filter_parameters"], dict):
|
398
|
+
metadata["filter_parameters"].update(payload)
|
399
|
+
else:
|
400
|
+
metadata["filter_parameters"].append(payload)
|
401
|
+
|
396
402
|
metadata["used_filter"] = filter_name
|
397
403
|
new_layer.metadata = metadata
|
398
404
|
|
@@ -450,7 +456,29 @@ def box_mask(
|
|
450
456
|
mask_type="box",
|
451
457
|
shape=template.shape,
|
452
458
|
center=(center_x, center_y, center_z),
|
453
|
-
|
459
|
+
size=(height_x, height_y, height_z),
|
460
|
+
sigma_decay=sigma_decay,
|
461
|
+
)
|
462
|
+
|
463
|
+
|
464
|
+
def membrane_mask(
|
465
|
+
template: NDArray,
|
466
|
+
symmetry_axis: int,
|
467
|
+
center_x: float,
|
468
|
+
center_y: float,
|
469
|
+
center_z: float,
|
470
|
+
radius: float,
|
471
|
+
thickness: float = 1,
|
472
|
+
separation: float = 3,
|
473
|
+
sigma_decay: float = 0,
|
474
|
+
**kwargs,
|
475
|
+
) -> NDArray:
|
476
|
+
return create_mask(
|
477
|
+
mask_type="membrane",
|
478
|
+
shape=template.shape,
|
479
|
+
radius=radius,
|
480
|
+
thickness=thickness,
|
481
|
+
separation=separation,
|
454
482
|
sigma_decay=sigma_decay,
|
455
483
|
)
|
456
484
|
|
@@ -471,7 +499,7 @@ def tube_mask(
|
|
471
499
|
mask_type="tube",
|
472
500
|
shape=template.shape,
|
473
501
|
symmetry_axis=symmetry_axis,
|
474
|
-
|
502
|
+
center=(center_x, center_y, center_z),
|
475
503
|
inner_radius=inner_radius,
|
476
504
|
outer_radius=outer_radius,
|
477
505
|
height=height,
|
@@ -584,6 +612,7 @@ class MaskWidget(widgets.Container):
|
|
584
612
|
"Ellipsoid": ellipsod_mask,
|
585
613
|
"Tube": tube_mask,
|
586
614
|
"Box": box_mask,
|
615
|
+
"Membrane": membrane_mask,
|
587
616
|
"Wedge": wedge_mask,
|
588
617
|
"Threshold": threshold_mask,
|
589
618
|
"Lowpass": lowpass_mask,
|
@@ -817,7 +846,7 @@ class AlignmentWidget(widgets.Container):
|
|
817
846
|
principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
|
818
847
|
|
819
848
|
rotation_matrix = align_vectors(principal_eigenvector, alignment_axis)
|
820
|
-
rotated_data, _ =
|
849
|
+
rotated_data, _ = be.rigid_transform(
|
821
850
|
arr=active_layer.data,
|
822
851
|
rotation_matrix=rotation_matrix,
|
823
852
|
use_geometric_center=False,
|
@@ -953,7 +982,6 @@ class PointCloudWidget(widgets.Container):
|
|
953
982
|
if not isinstance(layer, napari.layers.Points):
|
954
983
|
continue
|
955
984
|
|
956
|
-
layer.face_color = "white"
|
957
985
|
if event == "Label":
|
958
986
|
if len(layer.properties.get("detail", ())) == 0:
|
959
987
|
continue
|
@@ -970,9 +998,7 @@ class PointCloudWidget(widgets.Container):
|
|
970
998
|
layer.face_color = "score_scaled"
|
971
999
|
layer.face_colormap = "turbo"
|
972
1000
|
layer.face_color_mode = "colormap"
|
973
|
-
|
974
1001
|
layer.refresh_colors()
|
975
|
-
|
976
1002
|
return None
|
977
1003
|
|
978
1004
|
def _set_positive(self, event):
|
@@ -1140,9 +1166,18 @@ class MatchingWidget(widgets.Container):
|
|
1140
1166
|
self.viewer = viewer
|
1141
1167
|
self.dataframes = {}
|
1142
1168
|
|
1169
|
+
option_container = widgets.Container(layout="horizontal")
|
1170
|
+
self.load_target_checkbox = widgets.CheckBox(text="Load Target", value=False)
|
1171
|
+
self.load_rotations_checkbox = widgets.CheckBox(
|
1172
|
+
text="Load Rotations", value=False
|
1173
|
+
)
|
1174
|
+
option_container.append(self.load_target_checkbox)
|
1175
|
+
option_container.append(self.load_rotations_checkbox)
|
1176
|
+
|
1143
1177
|
self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
|
1144
1178
|
self.import_button.clicked.connect(self._get_load_path)
|
1145
1179
|
|
1180
|
+
self.append(option_container)
|
1146
1181
|
self.append(self.import_button)
|
1147
1182
|
|
1148
1183
|
def _get_load_path(self, event):
|
@@ -1150,7 +1185,7 @@ class MatchingWidget(widgets.Container):
|
|
1150
1185
|
self.native,
|
1151
1186
|
"Open Pickle File...",
|
1152
1187
|
"",
|
1153
|
-
"Pickle Files (*.pickle);;All Files (*)",
|
1188
|
+
"Pickle Files (*.pickle *pickle.gz);;All Files (*)",
|
1154
1189
|
)
|
1155
1190
|
if filename:
|
1156
1191
|
self._load_data(filename)
|
@@ -1159,14 +1194,35 @@ class MatchingWidget(widgets.Container):
|
|
1159
1194
|
data = load_pickle(filename)
|
1160
1195
|
|
1161
1196
|
fname = basename(filename).replace(".pickle", "")
|
1197
|
+
|
1198
|
+
if self.load_target_checkbox.value:
|
1199
|
+
try:
|
1200
|
+
target = Density.from_file(data[-1][-1].target)
|
1201
|
+
_ = self.viewer.add_image(
|
1202
|
+
data=target.data,
|
1203
|
+
name=f"{fname}_target",
|
1204
|
+
metadata={
|
1205
|
+
"origin": target.origin,
|
1206
|
+
"sampling_rate": target.sampling_rate,
|
1207
|
+
},
|
1208
|
+
)
|
1209
|
+
except Exception as e:
|
1210
|
+
msg = QMessageBox(self.native)
|
1211
|
+
msg.setIcon(QMessageBox.Warning)
|
1212
|
+
msg.setWindowTitle("Loading Error")
|
1213
|
+
msg.setText(str(e))
|
1214
|
+
msg.setStandardButtons(QMessageBox.Ok)
|
1215
|
+
msg.exec_()
|
1216
|
+
|
1162
1217
|
if data[0].ndim == data[2].ndim:
|
1163
1218
|
metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1219
|
+
if self.load_rotations_checkbox.value:
|
1220
|
+
_ = self.viewer.add_image(
|
1221
|
+
data=data[2],
|
1222
|
+
name=f"{fname}_rotations",
|
1223
|
+
colormap="orange",
|
1224
|
+
metadata=metadata,
|
1225
|
+
)
|
1170
1226
|
_ = self.viewer.add_image(
|
1171
1227
|
data=data[0],
|
1172
1228
|
name=f"{fname}_scores",
|
@@ -1174,11 +1230,8 @@ class MatchingWidget(widgets.Container):
|
|
1174
1230
|
metadata=metadata,
|
1175
1231
|
)
|
1176
1232
|
return None
|
1177
|
-
detail = np.zeros_like(data[2])
|
1178
|
-
else:
|
1179
|
-
detail = data[3]
|
1180
1233
|
|
1181
|
-
point_properties = {"score": data[2], "detail":
|
1234
|
+
point_properties = {"score": data[2], "detail": data[3]}
|
1182
1235
|
point_properties["score_scaled"] = np.log1p(
|
1183
1236
|
point_properties["score"] - point_properties["score"].min()
|
1184
1237
|
)
|
@@ -1191,8 +1244,26 @@ class MatchingWidget(widgets.Container):
|
|
1191
1244
|
)
|
1192
1245
|
|
1193
1246
|
|
1247
|
+
class CustomNapariViewer(napari.Viewer):
|
1248
|
+
"""
|
1249
|
+
Custom viewer to ensure 3D image layers are by default shown as xy projection.
|
1250
|
+
"""
|
1251
|
+
|
1252
|
+
def add_image(self, data, **kwargs):
|
1253
|
+
viewer_ndim = len(self.dims.order)
|
1254
|
+
layer = super().add_image(data, **kwargs)
|
1255
|
+
|
1256
|
+
try:
|
1257
|
+
# Set to xy view the first time data is opened
|
1258
|
+
if viewer_ndim != 3 and data.ndim == 3:
|
1259
|
+
self.dims.order = (2, 0, 1)
|
1260
|
+
except Exception:
|
1261
|
+
pass
|
1262
|
+
return layer
|
1263
|
+
|
1264
|
+
|
1194
1265
|
def main():
|
1195
|
-
viewer =
|
1266
|
+
viewer = CustomNapariViewer()
|
1196
1267
|
|
1197
1268
|
filter_widget = FilterWidget(preprocessor, viewer)
|
1198
1269
|
mask_widget = MaskWidget(viewer)
|