pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
  4. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
  5. pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
  7. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
  8. scripts/estimate_ram_usage.py +97 -0
  9. scripts/extract_candidates.py +118 -99
  10. scripts/match_template.py +30 -41
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +35 -21
  13. scripts/preprocessor_gui.py +96 -24
  14. scripts/pytme_runner.py +644 -190
  15. scripts/refine_matches.py +158 -390
  16. tests/data/.DS_Store +0 -0
  17. tests/data/Blurring/.DS_Store +0 -0
  18. tests/data/Maps/.DS_Store +0 -0
  19. tests/data/Raw/.DS_Store +0 -0
  20. tests/data/Structures/.DS_Store +0 -0
  21. tests/preprocessing/test_utils.py +18 -0
  22. tests/test_analyzer.py +2 -3
  23. tests/test_backends.py +3 -9
  24. tests/test_density.py +0 -1
  25. tests/test_extensions.py +0 -1
  26. tests/test_matching_utils.py +10 -60
  27. tests/test_orientations.py +0 -12
  28. tests/test_rotations.py +1 -1
  29. tme/__version__.py +1 -1
  30. tme/analyzer/_utils.py +4 -4
  31. tme/analyzer/aggregation.py +35 -15
  32. tme/analyzer/peaks.py +11 -10
  33. tme/backends/_jax_utils.py +64 -18
  34. tme/backends/_numpyfftw_utils.py +270 -0
  35. tme/backends/cupy_backend.py +16 -55
  36. tme/backends/jax_backend.py +79 -40
  37. tme/backends/matching_backend.py +17 -51
  38. tme/backends/mlx_backend.py +1 -27
  39. tme/backends/npfftw_backend.py +71 -65
  40. tme/backends/pytorch_backend.py +1 -26
  41. tme/density.py +58 -5
  42. tme/extensions.cpython-311-darwin.so +0 -0
  43. tme/filters/ctf.py +22 -21
  44. tme/filters/wedge.py +10 -7
  45. tme/mask.py +341 -0
  46. tme/matching_data.py +31 -19
  47. tme/matching_exhaustive.py +37 -47
  48. tme/matching_optimization.py +2 -1
  49. tme/matching_scores.py +229 -411
  50. tme/matching_utils.py +73 -422
  51. tme/memory.py +1 -1
  52. tme/orientations.py +24 -13
  53. tme/rotations.py +1 -1
  54. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  55. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
  56. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
  57. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
  58. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
  59. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
  60. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,97 @@
+ #!python
+ """ Estimate RAM requirements for template matching jobs.
+
+ Copyright (c) 2023 European Molecular Biology Laboratory
+
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ """
+ import numpy as np
+ import argparse
+ from tme import Density
+ from tme.matching_utils import estimate_ram_usage
+ from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Estimate RAM usage for template matching."
+     )
+     parser.add_argument(
+         "-m",
+         "--target",
+         dest="target",
+         type=str,
+         required=True,
+         help="Path to a target in CCP4/MRC format.",
+     )
+     parser.add_argument(
+         "-i",
+         "--template",
+         dest="template",
+         type=str,
+         required=True,
+         help="Path to a template in PDB/MMCIF or CCP4/MRC format.",
+     )
+     parser.add_argument(
+         "--matching_method",
+         required=False,
+         default=None,
+         help="Analyzer method to use.",
+     )
+     parser.add_argument(
+         "-s",
+         dest="score",
+         type=str,
+         default="FLCSphericalMask",
+         help="Template matching scoring function.",
+         choices=MATCHING_EXHAUSTIVE_REGISTER.keys(),
+     )
+     parser.add_argument(
+         "--ncores", type=int, help="Number of cores for parallelization.", required=True
+     )
+     parser.add_argument(
+         "--no_edge_padding",
+         dest="no_edge_padding",
+         action="store_true",
+         default=False,
+         help="Whether to pad the edges of the target. This is useful, if the target"
+         " has a well defined bounding box, e.g. a density map.",
+     )
+     parser.add_argument(
+         "--no_fourier_padding",
+         dest="no_fourier_padding",
+         action="store_true",
+         default=False,
+         help="Whether input arrays should be zero-padded to the full convolution shape"
+         " for numerical stability. When working with very large targets such as"
+         " tomograms it is safe to use this flag and benefit from the performance gain.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     target = Density.from_file(args.target)
+     template = Density.from_file(args.template)
+
+     target_box = target.shape
+     if not args.no_edge_padding:
+         target_box = np.add(target_box, template.shape)
+
+     template_box = template.shape
+     if args.no_fourier_padding:
+         template_box = np.ones(len(template_box), dtype=int)
+
+     result = estimate_ram_usage(
+         shape1=target_box,
+         shape2=template_box,
+         matching_method=args.score,
+         ncores=args.ncores,
+         analyzer_method="MaxScoreOverRotations",
+     )
+     print(result)
+
+
+ if __name__ == "__main__":
+     main()
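The new script is a thin CLI wrapper around a single library call. As a minimal sketch of the equivalent direct API usage (file paths and core count are placeholders; everything else mirrors the script above):

import numpy as np

from tme import Density
from tme.matching_utils import estimate_ram_usage

target = Density.from_file("tomogram.mrc")
template = Density.from_file("ribosome.mrc")

# Default behavior: pad the target box by the template shape,
# matching the script when --no_edge_padding is not given.
target_box = np.add(target.shape, template.shape)

result = estimate_ram_usage(
    shape1=target_box,
    shape2=template.shape,
    matching_method="FLCSphericalMask",
    ncores=4,
    analyzer_method="MaxScoreOverRotations",
)
print(result)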
@@ -12,8 +12,8 @@ from sys import exit
  from time import time
  from typing import Tuple
  from copy import deepcopy
- from os.path import exists
  from tempfile import gettempdir
+ from os.path import exists, abspath

  import numpy as np

@@ -359,8 +359,8 @@ def parse_args():
          "--invert-target-contrast",
          action="store_true",
          default=False,
-         help="Invert the target's contrast for cases where templates to-be-matched have "
-         "negative values, e.g. tomograms.",
+         help="Invert the target contrast. Useful for matching on tomograms if the "
+         "template has not been inverted.",
      )
      io_group.add_argument(
          "--scramble-phases",
@@ -576,7 +576,7 @@ def parse_args():
          "'angles', or a single column file without header. Exposure will be taken from "
          "the input file , if you are using a tab-separated file, the column names "
          "'angles' and 'weights' need to be present. It is also possible to specify a "
-         "continuous wedge mask using e.g., -50,45.",
+         "continuous wedge mask using e.g., 50,45.",
      )
      filter_group.add_argument(
          "--tilt-weighting",
@@ -685,13 +685,6 @@ def parse_args():
          help="Useful if the target does not have a well-defined bounding box. Will be "
          "activated automatically if splitting is required to avoid boundary artifacts.",
      )
-     performance_group.add_argument(
-         "--pad-filter",
-         action="store_true",
-         default=False,
-         help="Pad the template filter to the shape of the target. Useful for fast "
-         "oscilating filters to avoid aliasing effects.",
-     )
      performance_group.add_argument(
          "--interpolation-order",
          required=False,
@@ -700,13 +693,6 @@ def parse_args():
          help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
          "and pytorch backends.",
      )
-     performance_group.add_argument(
-         "--use-mixed-precision",
-         action="store_true",
-         default=False,
-         help="Use float16 for real values operations where possible. Not supported "
-         "for jax backend.",
-     )
      performance_group.add_argument(
          "--use-memmap",
          action="store_true",
@@ -742,6 +728,7 @@ def parse_args():
      args.interpolation_order = 3
      if args.backend in ("jax", "pytorch"):
          args.interpolation_order = 1
+         args.reconstruction_interpolation_order = 1

      if args.interpolation_order < 0:
          args.interpolation_order = None
@@ -779,6 +766,14 @@ def parse_args():
        )
        args.orientations = orientations

+     args.target = abspath(args.target)
+     if args.target_mask is not None:
+         args.target_mask = abspath(args.target_mask)
+
+     args.template = abspath(args.template)
+     if args.template_mask is not None:
+         args.template_mask = abspath(args.template_mask)
+
      return args


@@ -796,15 +791,23 @@ def main():
          sampling_rate=target.sampling_rate,
      )

+     if np.allclose(target.sampling_rate, 1):
+         warnings.warn(
+             "Target sampling rate is 1.0, which may indicate missing or incorrect "
+             "metadata. Verify that your target file contains proper sampling rate "
+             "information, as filters (CTF, BandPass) require accurate sampling rates "
+             "to function correctly."
+         )
+
      if target.sampling_rate.size == template.sampling_rate.size:
          if not np.allclose(
              np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
          ):
-             print(
-                 f"Resampling template to {target.sampling_rate}. "
-                 "Consider providing a template with the same sampling rate as the target."
+             warnings.warn(
+                 f"Sampling rate mismatch detected: target={target.sampling_rate} "
+                 f"template={template.sampling_rate}. Proceeding with user-provided "
+                 f"values. Make sure this is intentional. "
              )
-             template = template.resample(target.sampling_rate, order=3)

      template_mask = load_and_validate_mask(
          mask_target=template, mask_path=args.template_mask
@@ -881,16 +884,13 @@ def main():
      print("\n" + "-" * 80)

      if args.scramble_phases:
-         template.data = scramble_phases(
-             template.data, noise_proportion=1.0, normalize_power=False
-         )
+         template.data = scramble_phases(template.data, noise_proportion=1.0)

      callback_class = MaxScoreOverRotations
-     if args.peak_calling:
-         callback_class = PeakCallerMaximumFilter
-
      if args.orientations is not None:
          callback_class = MaxScoreOverRotationsConstrained
+     elif args.peak_calling:
+         callback_class = PeakCallerMaximumFilter

      # Determine suitable backend for the selected operation
      available_backends = be.available_backends()
@@ -939,16 +939,6 @@ def main():
          args.use_gpu = False
          be.change_backend("pytorch", device=device)

-     # TODO: Make the inverse casting from complex64 -> float 16 stable
-     # if args.use_mixed_precision:
-     #     be.change_backend(
-     #         backend_name=args.backend,
-     #         float_dtype=be._array_backend.float16,
-     #         complex_dtype=be._array_backend.complex64,
-     #         int_dtype=be._array_backend.int16,
-     #         device=device,
-     #     )
-
      available_memory = be.get_available_memory() * be.device_count()
      if args.memory is None:
          args.memory = int(args.memory_scaling * available_memory)
@@ -976,6 +966,8 @@ def main():
          target_dim=target.metadata.get("batch_dimension", None),
          template_dim=template.metadata.get("batch_dimension", None),
      )
+     args.batch_dims = tuple(int(x) for x in np.where(matching_data._batch_mask)[0])
+
      splits, schedule = compute_schedule(args, matching_data, callback_class)

      n_splits = np.prod(list(splits.values()))
@@ -1004,7 +996,6 @@ def main():
      compute_options = {
          "Backend": be._BACKEND_REGISTRY[be._backend_name],
          "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
-         "Use Mixed Precision": args.use_mixed_precision,
          "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
          "Temporary Directory": args.temp_directory,
          "Target Splits": f"{target_split} [N={n_splits}]",
@@ -1025,7 +1016,6 @@ def main():
          "Tilt Angles": args.tilt_angles,
          "Tilt Weighting": args.tilt_weighting,
          "Reconstruction Filter": args.reconstruction_filter,
-         "Extend Filter Grid": args.pad_filter,
      }
      if args.ctf_file is not None or args.defocus is not None:
          filter_args["CTF File"] = args.ctf_file
@@ -1081,7 +1071,6 @@ def main():
          callback_class_args=analyzer_args,
          target_splits=splits,
          pad_target_edges=args.pad_edges,
-         pad_template_filter=args.pad_filter,
          interpolation_order=args.interpolation_order,
      )

@@ -17,7 +17,7 @@ from scipy.special import erfcinv

  from tme import Density, Structure, Orientations
  from tme.cli import sanitize_name, print_block, print_entry
- from tme.matching_utils import load_pickle, centered_mask, write_pickle
+ from tme.matching_utils import load_pickle, center_slice, write_pickle
  from tme.matching_optimization import create_score_object, optimize_match
  from tme.rotations import euler_to_rotationmatrix, euler_from_rotationmatrix
  from tme.analyzer import (
@@ -87,6 +87,11 @@ def parse_args():
          help="Output prefix. Defaults to basename of first input. Extension is "
          "added with respect to chosen output format.",
      )
+     output_group.add_argument(
+         "--angles-clockwise",
+         action="store_true",
+         help="Report Euler angles in clockwise format expected by RELION.",
+     )
      output_group.add_argument(
          "--output-format",
          choices=[
@@ -112,7 +117,7 @@ def parse_args():
      peak_group.add_argument(
          "--peak-caller",
          choices=list(PEAK_CALLERS.keys()),
-         default="PeakCallerScipy",
+         default="PeakCallerMaximumFilter",
          help="Peak caller for local maxima identification.",
      )
      peak_group.add_argument(
@@ -183,7 +188,7 @@ def parse_args():
      )
      additional_group.add_argument(
          "--n-false-positives",
-         type=int,
+         type=float,
          default=None,
          required=False,
          help="Number of accepted false-positives picks to determine minimum score.",
@@ -313,11 +318,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
          data = load_matching_output(foreground)
          scores, _, rotations, rotation_mapping, *_ = data

-         # We could normalize to unit sdev, but that might lead to unexpected
-         # results for flat background distributions
-         scores -= scores.mean()
          indices = tuple(slice(0, x) for x in scores.shape)
-
          indices_update = scores > scores_out[indices]
          scores_out[indices][indices_update] = scores[indices_update]

@@ -364,9 +365,8 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
      scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
      for background in backgrounds:
          data_norm = load_matching_output(background)
+         scores, _, rotations, rotation_mapping, *_ = data_norm

-         scores = data_norm[0]
-         scores -= scores.mean()
          indices = tuple(slice(0, x) for x in scores.shape)
          indices_update = scores > scores_norm[indices]
          scores_norm[indices][indices_update] = scores[indices_update]
@@ -375,8 +375,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
      update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
      scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
      scores_out[update] = data[0][update] - scores_norm[update]
-     scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
      scores_out = np.fmax(scores_out, 0, out=scores_out)
+     scores_out[update] += scores_norm[update].mean()
+
+     # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
      data[0] = scores_out

      fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
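Net effect of this hunk on the background correction: the previous ratio normalization (kept above only as a comment) is replaced by a clipped difference re-centered on the background mean,

$$ s_{\mathrm{out}} = \max\left(s_{\mathrm{fg}} - s_{\mathrm{bg}},\, 0\right) + \overline{s_{\mathrm{bg}}}, $$

where $s_{\mathrm{fg}}$ and $s_{\mathrm{bg}}$ are the element-wise foreground and background score maps.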
@@ -478,15 +480,21 @@ def main():
      if orientations is None:
          translations, rotations, scores, details = [], [], [], []

-         # Data processed by normalize_input is guaranteed to have this shape
-         scores, offset, rotation_array, rotation_mapping, meta = data
+         var = None
+         # Data processed by normalize_input is guaranteed to have this shape)
+         scores, _, rotation_array, rotation_mapping, *_ = data
+         if len(data) == 6:
+             scores, _, rotation_array, rotation_mapping, var, *_ = data

          cropped_shape = np.subtract(
              scores.shape, np.multiply(args.min_boundary_distance, 2)
          ).astype(int)

          if args.min_boundary_distance > 0:
-             scores = centered_mask(scores, new_shape=cropped_shape)
+             _scores = np.zeros_like(scores)
+             subset = center_slice(scores.shape, cropped_shape)
+             _scores[subset] = scores[subset]
+             scores = _scores

          if args.n_false_positives is not None:
              # Rickgauer et al. 2017
@@ -499,17 +507,20 @@ def main():
              )
              args.n_false_positives = max(args.n_false_positives, 1)
              n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
+             std = np.std(scores[cropped_slice])
+             if var is not None:
+                 std = np.asarray(np.sqrt(var)).reshape(())
+
              minimum_score = np.multiply(
                  erfcinv(2 * args.n_false_positives / n_correlations),
-                 np.sqrt(2) * np.std(scores[cropped_slice]),
+                 np.sqrt(2) * std,
              )
-             print(f"Determined minimum score cutoff: {minimum_score}.")
-             minimum_score = max(minimum_score, 0)
-             args.min_score = minimum_score
+             print(f"Determined cutoff --min-score {minimum_score}.")
+             args.min_score = max(minimum_score, 0)

          args.batch_dims = None
-         if hasattr(cli_args, "target_batch"):
-             args.batch_dims = cli_args.target_batch
+         if hasattr(cli_args, "batch_dims"):
+             args.batch_dims = cli_args.batch_dims

          peak_caller_kwargs = {
              "shape": scores.shape,
@@ -517,8 +528,8 @@ def main():
              "min_distance": args.min_distance,
              "min_boundary_distance": args.min_boundary_distance,
              "batch_dims": args.batch_dims,
-             "minimum_score": args.min_score,
-             "maximum_score": args.max_score,
+             "min_score": args.min_score,
+             "max_score": args.max_score,
          }

          peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
@@ -611,6 +622,9 @@ def main():
          orientations.rotations[index] = angles
          orientations.scores[index] = score * -1

+     if args.angles_clockwise:
+         orientations.rotations *= -1
+
      if args.output_format in ("orientations", "relion4", "relion5"):
          file_format, extension = "text", "tsv"
@@ -19,13 +19,13 @@ from magicgui import widgets
  from numpy.typing import NDArray
  from napari.layers import Image
  from scipy.fft import next_fast_len
- from qtpy.QtWidgets import QFileDialog
  from napari.utils.events import EventedList
+ from qtpy.QtWidgets import QFileDialog, QMessageBox

- from tme.backends import backend
+ from tme.backends import backend as be
  from tme.rotations import align_vectors
- from tme import Preprocessor, Density, Orientations
  from tme.matching_utils import create_mask, load_pickle
+ from tme import Preprocessor, Density, Orientations
  from tme.filters import BandPassReconstructed, CTFReconstructed

  preprocessor = Preprocessor()
@@ -69,7 +69,7 @@ def ctf_filter(
      flip_phase: bool = False,
  ) -> NDArray:
      fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
-     template_pad = backend.topleft_pad(template, fast_shape)
+     template_pad = be.topleft_pad(template, fast_shape)
      template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
      ctf = CTFReconstructed(
          shape=fast_shape,
@@ -85,7 +85,7 @@ def ctf_filter(
      )
      np.multiply(template_ft, ctf()["data"], out=template_ft)
      template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
-     template = backend.topleft_pad(template_pad, template.shape)
+     template = be.topleft_pad(template_pad, template.shape)
      return template


@@ -392,7 +392,13 @@ class FilterWidget(widgets.Container):
          metadata = selected_layer_metadata.copy()
          if "filter_parameters" not in metadata:
              metadata["filter_parameters"] = []
-         metadata["filter_parameters"].append({filter_name: kwargs.copy()})
+
+         payload = {filter_name: kwargs.copy()}
+         if isinstance(metadata["filter_parameters"], dict):
+             metadata["filter_parameters"].update(payload)
+         else:
+             metadata["filter_parameters"].append(payload)
+
          metadata["used_filter"] = filter_name
          new_layer.metadata = metadata

@@ -450,7 +456,30 @@ def box_mask(
          mask_type="box",
          shape=template.shape,
          center=(center_x, center_y, center_z),
-         height=(height_x, height_y, height_z),
+         size=(height_x, height_y, height_z),
+         sigma_decay=sigma_decay,
+     )
+
+
+ def membrane_mask(
+     template: NDArray,
+     symmetry_axis: int,
+     center_x: float,
+     center_y: float,
+     center_z: float,
+     radius: float,
+     thickness: float = 1,
+     separation: float = 3,
+     sigma_decay: float = 0,
+     **kwargs,
+ ) -> NDArray:
+     return create_mask(
+         center=(center_x, center_y, center_z),
+         mask_type="membrane",
+         shape=template.shape,
+         radius=radius,
+         thickness=thickness,
+         separation=separation,
          sigma_decay=sigma_decay,
      )

@@ -471,7 +500,7 @@ def tube_mask(
          mask_type="tube",
          shape=template.shape,
          symmetry_axis=symmetry_axis,
-         base_center=(center_x, center_y, center_z),
+         center=(center_x, center_y, center_z),
          inner_radius=inner_radius,
          outer_radius=outer_radius,
          height=height,
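The membrane_mask wrapper added above exposes a new "membrane" mask type through create_mask. A minimal usage sketch, using only the parameters visible in that wrapper; the shape and center values are placeholders and the parameter semantics in the comments are assumptions:

from tme.matching_utils import create_mask

mask = create_mask(
    mask_type="membrane",
    shape=(64, 64, 64),
    center=(32, 32, 32),
    radius=20,       # lateral extent of the membrane patch (assumed)
    thickness=1,     # thickness of a single leaflet (assumed)
    separation=3,    # spacing between the two leaflets (assumed)
    sigma_decay=0,   # soft Gaussian falloff at the mask edge (assumed)
)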
@@ -584,6 +613,7 @@ class MaskWidget(widgets.Container):
             "Ellipsoid": ellipsod_mask,
             "Tube": tube_mask,
             "Box": box_mask,
+            "Membrane": membrane_mask,
             "Wedge": wedge_mask,
             "Threshold": threshold_mask,
             "Lowpass": lowpass_mask,
@@ -817,7 +847,7 @@ class AlignmentWidget(widgets.Container):
          principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

          rotation_matrix = align_vectors(principal_eigenvector, alignment_axis)
-         rotated_data, _ = backend.rigid_transform(
+         rotated_data, _ = be.rigid_transform(
              arr=active_layer.data,
              rotation_matrix=rotation_matrix,
              use_geometric_center=False,
@@ -953,7 +983,6 @@ class PointCloudWidget(widgets.Container):
              if not isinstance(layer, napari.layers.Points):
                  continue

-             layer.face_color = "white"
              if event == "Label":
                  if len(layer.properties.get("detail", ())) == 0:
                      continue
@@ -970,9 +999,7 @@ class PointCloudWidget(widgets.Container):
                  layer.face_color = "score_scaled"
                  layer.face_colormap = "turbo"
                  layer.face_color_mode = "colormap"
-
                  layer.refresh_colors()
-
          return None

      def _set_positive(self, event):
@@ -1140,9 +1167,18 @@ class MatchingWidget(widgets.Container):
          self.viewer = viewer
          self.dataframes = {}

+         option_container = widgets.Container(layout="horizontal")
+         self.load_target_checkbox = widgets.CheckBox(text="Load Target", value=False)
+         self.load_rotations_checkbox = widgets.CheckBox(
+             text="Load Rotations", value=False
+         )
+         option_container.append(self.load_target_checkbox)
+         option_container.append(self.load_rotations_checkbox)
+
          self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
          self.import_button.clicked.connect(self._get_load_path)

+         self.append(option_container)
          self.append(self.import_button)

      def _get_load_path(self, event):
@@ -1150,7 +1186,7 @@ class MatchingWidget(widgets.Container):
              self.native,
              "Open Pickle File...",
              "",
-             "Pickle Files (*.pickle);;All Files (*)",
+             "Pickle Files (*.pickle *pickle.gz);;All Files (*)",
          )
          if filename:
              self._load_data(filename)
@@ -1159,14 +1195,35 @@ class MatchingWidget(widgets.Container):
          data = load_pickle(filename)

          fname = basename(filename).replace(".pickle", "")
+
+         if self.load_target_checkbox.value:
+             try:
+                 target = Density.from_file(data[-1][-1].target)
+                 _ = self.viewer.add_image(
+                     data=target.data,
+                     name=f"{fname}_target",
+                     metadata={
+                         "origin": target.origin,
+                         "sampling_rate": target.sampling_rate,
+                     },
+                 )
+             except Exception as e:
+                 msg = QMessageBox(self.native)
+                 msg.setIcon(QMessageBox.Warning)
+                 msg.setWindowTitle("Loading Error")
+                 msg.setText(str(e))
+                 msg.setStandardButtons(QMessageBox.Ok)
+                 msg.exec_()
+
          if data[0].ndim == data[2].ndim:
              metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
-             _ = self.viewer.add_image(
-                 data=data[2],
-                 name=f"{fname}_rotations",
-                 colormap="orange",
-                 metadata=metadata,
-             )
+             if self.load_rotations_checkbox.value:
+                 _ = self.viewer.add_image(
+                     data=data[2],
+                     name=f"{fname}_rotations",
+                     colormap="orange",
+                     metadata=metadata,
+                 )
              _ = self.viewer.add_image(
                  data=data[0],
                  name=f"{fname}_scores",
@@ -1174,11 +1231,8 @@ class MatchingWidget(widgets.Container):
                  metadata=metadata,
              )
              return None
-             detail = np.zeros_like(data[2])
-         else:
-             detail = data[3]

-         point_properties = {"score": data[2], "detail": detail}
+         point_properties = {"score": data[2], "detail": data[3]}
          point_properties["score_scaled"] = np.log1p(
              point_properties["score"] - point_properties["score"].min()
          )
@@ -1191,8 +1245,26 @@ class MatchingWidget(widgets.Container):
          )


+ class CustomNapariViewer(napari.Viewer):
+     """
+     Custom viewer to ensure 3D image layers are by default shown as xy projection.
+     """
+
+     def add_image(self, data, **kwargs):
+         viewer_ndim = len(self.dims.order)
+         layer = super().add_image(data, **kwargs)
+
+         try:
+             # Set to xy view the first time data is opened
+             if viewer_ndim != 3 and data.ndim == 3:
+                 self.dims.order = (2, 0, 1)
+         except Exception:
+             pass
+         return layer
+
+
  def main():
-     viewer = napari.Viewer()
+     viewer = CustomNapariViewer()

      filter_widget = FilterWidget(preprocessor, viewer)
      mask_widget = MaskWidget(viewer)