pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
  8. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
  65. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
@@ -87,11 +87,6 @@ def parse_args():
87
87
  help="Output prefix. Defaults to basename of first input. Extension is "
88
88
  "added with respect to chosen output format.",
89
89
  )
90
- output_group.add_argument(
91
- "--angles-clockwise",
92
- action="store_true",
93
- help="Report Euler angles in clockwise format expected by RELION.",
94
- )
95
90
  output_group.add_argument(
96
91
  "--output-format",
97
92
  choices=[
@@ -173,13 +168,6 @@ def parse_args():
173
168
  default=None,
174
169
  help="Box size of extracted subtomograms, defaults to the centered template.",
175
170
  )
176
- additional_group.add_argument(
177
- "--mask-subtomograms",
178
- action="store_true",
179
- default=False,
180
- help="Whether to mask subtomograms using the template mask. The mask will be "
181
- "rotated according to determined angles.",
182
- )
183
171
  additional_group.add_argument(
184
172
  "--invert-target-contrast",
185
173
  action="store_true",
@@ -219,20 +207,15 @@ def load_template(
219
207
  ):
220
208
  try:
221
209
  template = Density.from_file(filepath)
222
- center = np.divide(np.subtract(template.shape, 1), 2)
223
210
  template_is_density = True
224
211
  except Exception:
225
- template = Structure.from_file(filepath)
226
- center = template.center_of_mass()
227
- template = Density.from_structure(template, sampling_rate=sampling_rate)
212
+ template = Density.from_structure(filepath, sampling_rate=sampling_rate)
228
213
  template_is_density = False
229
214
 
230
- translation = np.zeros_like(center)
231
- if centering and template_is_density:
232
- template, translation = template.centered(0)
233
- center = np.divide(np.subtract(template.shape, 1), 2)
234
-
235
- return template, center, translation, template_is_density
215
+ if centering:
216
+ template = template.centered(0)
217
+ center = np.divide(np.subtract(template.shape, 1), 2)
218
+ return template, center, template_is_density
236
219
 
237
220
 
238
221
  def load_matching_output(path: str) -> List:
@@ -375,10 +358,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
375
358
  update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
376
359
  scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
377
360
  scores_out[update] = data[0][update] - scores_norm[update]
361
+ # scores_out = np.fmax(scores_out, 0, out=scores_out)
378
362
  scores_out[update] += scores_norm[update].mean()
379
363
 
380
364
  # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
381
- scores_out = np.fmax(scores_out, 0, out=scores_out)
382
365
  data[0] = scores_out
383
366
 
384
367
  fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
@@ -448,34 +431,24 @@ def main():
448
431
  if hasattr(cli_args, "no_centering"):
449
432
  cli_args.centering = not cli_args.no_centering
450
433
 
451
- _, template_extension = splitext(cli_args.template)
452
- ret = load_template(
434
+ template, *_ = load_template(
453
435
  filepath=cli_args.template,
454
436
  sampling_rate=sampling_rate,
455
437
  centering=cli_args.centering,
456
438
  )
457
- template, center_of_mass, translation, template_is_density = ret
458
439
 
459
440
  template_mask = template.empty
460
441
  template_mask.data[:] = 1
461
442
  if cli_args.template_mask is not None:
462
443
  template_mask = Density.from_file(cli_args.template_mask)
463
- template_mask.pad(template.shape, center=False)
464
- origin_translation = np.divide(
465
- np.subtract(template.origin, template_mask.origin), template.sampling_rate
466
- )
467
- translation = np.add(translation, origin_translation)
468
-
469
- template_mask = template_mask.rigid_transform(
470
- rotation_matrix=np.eye(template_mask.data.ndim),
471
- translation=-translation,
472
- order=1,
473
- )
444
+ if cli_args.centering:
445
+ template_mask.pad(template.shape, center=True)
474
446
 
475
447
  if args.mask_edges and args.min_boundary_distance == 0:
476
448
  max_shape = np.max(template.shape)
477
449
  args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))
478
450
 
451
+ # Do the actual peak calling
479
452
  orientations = args.orientations
480
453
  if orientations is None:
481
454
  translations, rotations, scores, details = [], [], [], []
@@ -518,18 +491,20 @@ def main():
518
491
  print(f"Determined cutoff --min-score {minimum_score}.")
519
492
  args.min_score = max(minimum_score, 0)
520
493
 
521
- args.batch_dims = None
522
- if hasattr(cli_args, "batch_dims"):
523
- args.batch_dims = cli_args.batch_dims
494
+ projection_dims = None
495
+ batch_dims = getattr(cli_args, "batch_dims", None)
496
+ if getattr(cli_args, "match_projection", False):
497
+ projection_dims = batch_dims
524
498
 
525
499
  peak_caller_kwargs = {
526
500
  "shape": scores.shape,
527
501
  "num_peaks": args.num_peaks,
528
502
  "min_distance": args.min_distance,
529
503
  "min_boundary_distance": args.min_boundary_distance,
530
- "batch_dims": args.batch_dims,
531
504
  "min_score": args.min_score,
532
505
  "max_score": args.max_score,
506
+ "batch_dims": batch_dims,
507
+ "projection_dims": projection_dims,
533
508
  }
534
509
 
535
510
  peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
@@ -551,13 +526,9 @@ def main():
551
526
  exit(-1)
552
527
 
553
528
  for translation, _, score, detail in zip(*candidates):
554
- rotation_index = rotation_array[tuple(translation)]
555
- rotation = rotation_mapping.get(
556
- rotation_index, np.zeros(template.data.ndim, int)
557
- )
558
- if rotation.ndim == 2:
559
- rotation = euler_from_rotationmatrix(rotation)
560
- rotations.append(rotation)
529
+ index = rotation_array[tuple(translation)]
530
+ rotation = rotation_mapping.get(index, np.eye(template.data.ndim))
531
+ rotations.append(euler_from_rotationmatrix(rotation, seq="ZYZ"))
561
532
 
562
533
  if len(rotations):
563
534
  rotations = np.vstack(rotations).astype(float)
@@ -583,11 +554,10 @@ def main():
583
554
 
584
555
  if args.peak_oversampling > 1:
585
556
  if data[0].ndim != data[2].ndim:
586
- print(
557
+ exit(
587
558
  "Input pickle does not contain template matching scores."
588
559
  " Cannot oversample peaks."
589
560
  )
590
- exit(-1)
591
561
  peak_caller = peak_caller = PEAK_CALLERS[args.peak_caller](shape=scores.shape)
592
562
  orientations.translations = peak_caller.oversample_peaks(
593
563
  scores=data[0],
@@ -597,8 +567,6 @@ def main():
597
567
 
598
568
  if args.local_optimization:
599
569
  target = Density.from_file(cli_args.target, use_memmap=True)
600
- orientations.translations = orientations.translations.astype(np.float32)
601
- orientations.rotations = orientations.rotations.astype(np.float32)
602
570
  for index, (translation, angles, *_) in enumerate(orientations):
603
571
  score_object = create_score_object(
604
572
  score="FLC",
@@ -619,12 +587,11 @@ def main():
619
587
  x0=[*init_translation, *angles],
620
588
  )
621
589
  orientations.translations[index] = np.add(translation, center)
622
- orientations.rotations[index] = angles
590
+ orientations.rotations[index] = euler_from_rotationmatrix(
591
+ rotation_matrix, seq="ZYZ"
592
+ )
623
593
  orientations.scores[index] = score * -1
624
594
 
625
- if args.angles_clockwise:
626
- orientations.rotations *= -1
627
-
628
595
  if args.output_format in ("orientations", "relion4", "relion5"):
629
596
  file_format, extension = "text", "tsv"
630
597
 
@@ -691,12 +658,6 @@ def main():
691
658
  sampling_rate=sampling_rate,
692
659
  origin=np.multiply(cand_start, sampling_rate),
693
660
  )
694
- if args.mask_subtomograms:
695
- rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
696
- mask_transfomed = template_mask.rigid_transform(
697
- rotation_matrix=rotation_matrix, order=1
698
- )
699
- out_density.data = out_density.data * mask_transfomed.data
700
661
  out_density.to_file(
701
662
  join(working_directory, f"{args.output_prefix}_{index}.mrc")
702
663
  )
@@ -713,10 +674,11 @@ def main():
713
674
  out = np.zeros_like(template.data)
714
675
  for index in range(len(cand_slices)):
715
676
  subset = Density(target.data[obs_slices[index]])
716
- rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
717
677
 
678
+ # We invert to pull the local into the global reference system
679
+ matrix = euler_to_rotationmatrix(orientations.rotations[index]).T
718
680
  subset = subset.rigid_transform(
719
- rotation_matrix=np.linalg.inv(rotation_matrix),
681
+ rotation_matrix=matrix,
720
682
  order=1,
721
683
  use_geometric_center=True,
722
684
  )
@@ -728,37 +690,37 @@ def main():
728
690
  ret.to_file(f"{args.output_prefix}.mrc")
729
691
  exit(0)
730
692
 
731
- template, center, *_ = load_template(
693
+ template, center, template_is_density, *_ = load_template(
732
694
  filepath=cli_args.template,
733
695
  sampling_rate=sampling_rate,
734
696
  centering=cli_args.centering,
735
697
  )
736
698
 
699
+ _, ext = splitext(cli_args.template)
737
700
  for index, (translation, angles, *_) in enumerate(orientations):
738
- rotation_matrix = euler_to_rotationmatrix(angles)
701
+ rotation = euler_to_rotationmatrix(angles, seq="ZYZ")
739
702
  if template_is_density:
740
703
  transformed_template = template.rigid_transform(
741
- rotation_matrix=rotation_matrix, use_geometric_center=True
704
+ rotation_matrix=rotation, use_geometric_center=True
742
705
  )
743
706
  # Just adapting the coordinate system not the in-box position
744
707
  shift = np.multiply(np.subtract(translation, center), sampling_rate)
745
708
  transformed_template.origin = np.add(target_origin, shift)
746
-
747
709
  else:
748
710
  template = Structure.from_file(cli_args.template)
749
- new_center_of_mass = np.add(
750
- np.multiply(translation, sampling_rate), target_origin
751
- )
752
- translation = np.subtract(new_center_of_mass, center)
711
+ shift = np.add(np.multiply(translation, sampling_rate), target_origin)
712
+ translation = np.subtract(shift, template.center_of_mass())
713
+
714
+ # Since we move the template's center of mass to the geometric center
715
+ # during matching and analysis we use the center of mass
716
+ # directly for rotating structures into the correct orientation
753
717
  transformed_template = template.rigid_transform(
754
718
  translation=translation,
755
- rotation_matrix=rotation_matrix,
719
+ rotation_matrix=rotation,
720
+ use_geometric_center=False,
756
721
  )
757
- # template_extension should contain '.'
758
- transformed_template.to_file(
759
- f"{args.output_prefix}_{index}{template_extension}"
760
- )
761
- index += 1
722
+
723
+ transformed_template.to_file(f"{args.output_prefix}_{index}{ext}")
762
724
 
763
725
 
764
726
  if __name__ == "__main__":
@@ -126,7 +126,7 @@ def main():
126
126
  if args.align_axis is not None:
127
127
  rmat = data.align_to_axis(axis=args.align_axis, flip=args.flip_axis)
128
128
  data = data.rigid_transform(
129
- rotation_matrix=rmat, translation=0, use_geometric_center=True
129
+ rotation_matrix=rmat.T, use_geometric_center=True
130
130
  )
131
131
  data = Density.from_structure(data, sampling_rate=sampling_rate)
132
132
 
@@ -138,11 +138,11 @@ def main():
138
138
  if args.align_axis is not None:
139
139
  rmat = data.align_to_axis(axis=args.align_axis, flip=args.flip_axis)
140
140
  data = data.rigid_transform(
141
- rotation_matrix=rmat, translation=0, use_geometric_center=True
141
+ rotation_matrix=rmat.T, use_geometric_center=True
142
142
  )
143
143
 
144
144
  if not args.no_centering:
145
- data, _ = data.centered(0)
145
+ data = data.centered(0)
146
146
 
147
147
  if args.box_size is None:
148
148
  scale = np.divide(data.sampling_rate, args.sampling_rate)
@@ -177,9 +177,8 @@ def main():
177
177
  lowpass=lowpass,
178
178
  highpass=None,
179
179
  use_gaussian=True,
180
- return_real_fourier=True,
181
180
  sampling_rate=data.sampling_rate,
182
- )(shape=data.shape)["data"]
181
+ )(shape=data.shape, return_real_fourier=True)["data"]
183
182
  bpf_mask = be.to_backend_array(bpf_mask)
184
183
 
185
184
  data_ft = be.rfftn(be.to_backend_array(data.data), s=data.shape)
@@ -26,14 +26,16 @@ from tme.backends import backend as be
26
26
  from tme.rotations import align_vectors
27
27
  from tme.matching_utils import create_mask, load_pickle
28
28
  from tme import Preprocessor, Density, Orientations
29
- from tme.filters import BandPassReconstructed, CTFReconstructed
29
+ from tme.filters import BandPassReconstructed, CTFReconstructed, WedgeReconstructed
30
30
 
31
31
  preprocessor = Preprocessor()
32
32
  SLIDER_MIN, SLIDER_MAX = 0, 25
33
33
 
34
34
 
35
- def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
36
- return preprocessor.gaussian_filter(template=template, sigma=sigma, **kwargs)
35
+ def _apply_fourier_filter(arr, arr_filter):
36
+ arr_ft = np.fft.rfftn(arr, s=arr.shape)
37
+ arr_ft = np.multiply(arr_ft, arr_filter, out=arr_ft)
38
+ return np.real(np.fft.irfftn(arr_ft, s=arr.shape))
37
39
 
38
40
 
39
41
  def bandpass_filter(
@@ -48,13 +50,8 @@ def bandpass_filter(
48
50
  highpass=highpass_angstrom,
49
51
  sampling_rate=np.max(sampling_rate),
50
52
  use_gaussian=not hard_edges,
51
- return_real_fourier=True,
52
- )
53
- template_ft = np.fft.rfftn(template, s=template.shape)
54
-
55
- mask = bpf(shape=template.shape)["data"]
56
- np.multiply(template_ft, mask, out=template_ft)
57
- return np.fft.irfftn(template_ft, s=template.shape).real
53
+ )(shape=template.shape, return_real_fourier=True)["data"]
54
+ return _apply_fourier_filter(template, bpf)
58
55
 
59
56
 
60
57
  def ctf_filter(
@@ -70,9 +67,7 @@ def ctf_filter(
70
67
  ) -> NDArray:
71
68
  fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
72
69
  template_pad = be.topleft_pad(template, fast_shape)
73
- template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
74
70
  ctf = CTFReconstructed(
75
- shape=fast_shape,
76
71
  defocus_x=[defocus_angstrom],
77
72
  acceleration_voltage=acceleration_voltage * 1e3,
78
73
  spherical_aberration=spherical_aberration * 1e7,
@@ -80,22 +75,18 @@ def ctf_filter(
80
75
  phase_shift=phase_shift,
81
76
  defocus_angle=defocus_angle,
82
77
  sampling_rate=np.max(sampling_rate),
83
- return_real_fourier=True,
84
78
  flip_phase=flip_phase,
85
- )
86
- np.multiply(template_ft, ctf()["data"], out=template_ft)
87
- template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
88
- template = be.topleft_pad(template_pad, template.shape)
89
- return template
79
+ )(shape=template.shape, return_real_fourier=True)["data"]
80
+ template = _apply_fourier_filter(template, ctf)
81
+ return be.topleft_pad(template_pad, template.shape)
90
82
 
91
83
 
92
- def difference_of_gaussian_filter(
93
- template: NDArray, sigmas: Tuple[float, float], **kwargs: dict
94
- ) -> NDArray:
95
- low_sigma, high_sigma = sigmas
96
- return preprocessor.difference_of_gaussian_filter(
97
- template=template, low_sigma=low_sigma, high_sigma=high_sigma, **kwargs
98
- )
84
+ def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
85
+ return preprocessor.gaussian_filter(template=template, sigma=sigma, **kwargs)
86
+
87
+
88
+ def difference_of_gaussian_filter(template, sigmas: Tuple[float, float], **kwargs):
89
+ return gaussian_filter(template, sigmas[0]) - gaussian_filter(template, sigmas[1])
99
90
 
100
91
 
101
92
  def edge_gaussian_filter(
@@ -132,11 +123,7 @@ def local_gaussian_filter(
132
123
  )
133
124
 
134
125
 
135
- def mean(
136
- template: NDArray,
137
- width: int,
138
- **kwargs: dict,
139
- ) -> NDArray:
126
+ def mean(template: NDArray, width: int, **kwargs: dict) -> NDArray:
140
127
  return preprocessor.mean_filter(template=template, width=width)
141
128
 
142
129
 
@@ -147,45 +134,23 @@ def wedge(
147
134
  tilt_step: float = 0,
148
135
  opening_axis: int = 2,
149
136
  tilt_axis: int = 0,
150
- omit_negative_frequencies: bool = False,
151
- infinite_plane: bool = False,
152
- weight_angle: bool = False,
153
137
  **kwargs,
154
138
  ) -> NDArray:
155
- template_ft = np.fft.fftn(template)
156
-
157
- if tilt_step <= 0:
158
- wedge_mask = preprocessor.continuous_wedge_mask(
159
- start_tilt=tilt_start,
160
- stop_tilt=tilt_stop,
161
- tilt_axis=tilt_axis,
162
- opening_axis=opening_axis,
163
- shape=template.shape,
164
- omit_negative_frequencies=omit_negative_frequencies,
165
- infinite_plane=infinite_plane,
166
- )
167
- else:
168
- weights = None
169
- tilt_angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
170
- if weight_angle:
171
- weights = np.cos(np.radians(tilt_angles))
172
-
173
- wedge_mask = preprocessor.step_wedge_mask(
174
- tilt_angles=tilt_angles,
175
- tilt_axis=tilt_axis,
176
- opening_axis=opening_axis,
177
- shape=template.shape,
178
- weights=weights,
179
- omit_negative_frequencies=omit_negative_frequencies,
180
- )
181
-
182
- np.multiply(template_ft, wedge_mask, out=template_ft)
183
- template = np.real(np.fft.ifftn(template_ft))
184
- return template
139
+ mask = wedge_mask(
140
+ template=template,
141
+ fftshift=False,
142
+ tilt_start=tilt_start,
143
+ tilt_stop=tilt_stop,
144
+ tilt_step=tilt_step,
145
+ return_real_fourier=True,
146
+ weight_angle=False,
147
+ )
148
+ return _apply_fourier_filter(template, mask)
185
149
 
186
150
 
187
151
  def compute_power_spectrum(template: NDArray) -> NDArray:
188
- return np.fft.fftshift(np.log(np.abs(np.fft.fftn(template))))
152
+ return np.fft.fftshift(np.log(1 + np.abs(np.fft.fftn(template))))
153
+ # return np.fft.fftshift(np.log(np.abs(np.fft.fftn(template))))
189
154
 
190
155
 
191
156
  def invert_contrast(template: NDArray) -> NDArray:
@@ -474,6 +439,7 @@ def membrane_mask(
474
439
  **kwargs,
475
440
  ) -> NDArray:
476
441
  return create_mask(
442
+ center=(center_x, center_y, center_z),
477
443
  mask_type="membrane",
478
444
  shape=template.shape,
479
445
  radius=radius,
@@ -514,39 +480,30 @@ def wedge_mask(
514
480
  tilt_step: float = 0,
515
481
  opening_axis: int = 2,
516
482
  tilt_axis: int = 0,
517
- omit_negative_frequencies: bool = False,
518
- infinite_plane: bool = False,
519
- weight_angle: bool = False,
520
483
  **kwargs,
521
484
  ) -> NDArray:
522
- if tilt_step <= 0:
523
- wedge_mask = preprocessor.continuous_wedge_mask(
524
- start_tilt=tilt_start,
525
- stop_tilt=tilt_stop,
526
- tilt_axis=tilt_axis,
527
- opening_axis=opening_axis,
528
- shape=template.shape,
529
- omit_negative_frequencies=omit_negative_frequencies,
530
- infinite_plane=infinite_plane,
531
- )
532
- wedge_mask = np.fft.fftshift(wedge_mask)
533
- return wedge_mask
534
-
535
- weights = None
536
- tilt_angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
537
- if weight_angle:
538
- weights = np.cos(np.radians(tilt_angles))
539
-
540
- wedge_mask = preprocessor.step_wedge_mask(
541
- tilt_angles=tilt_angles,
485
+ angles = (tilt_start, tilt_stop)
486
+ continuous_wedge = tilt_step == 0
487
+ if not continuous_wedge:
488
+ angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
489
+
490
+ return_real_fourier = kwargs.get("return_real_fourier", False)
491
+ func = WedgeReconstructed(
492
+ angles=angles,
542
493
  tilt_axis=tilt_axis,
543
494
  opening_axis=opening_axis,
544
- shape=template.shape,
545
- weights=weights,
546
- omit_negative_frequencies=omit_negative_frequencies,
495
+ frequency_cutoff=0.5,
496
+ create_continuous_wedge=continuous_wedge,
497
+ weight_wedge=kwargs.get("weight_angle", False),
547
498
  )
548
-
549
- wedge_mask = np.fft.fftshift(wedge_mask)
499
+ wedge_mask = func(shape=template.shape, return_real_fourier=return_real_fourier)[
500
+ "data"
501
+ ]
502
+ if kwargs.get("fftshift", True):
503
+ axes = [i for i in range(wedge_mask.ndim)]
504
+ if return_real_fourier:
505
+ _ = axes.pop(-1)
506
+ wedge_mask = np.fft.fftshift(wedge_mask, axes=axes)
550
507
  return wedge_mask
551
508
 
552
509
 
@@ -569,8 +526,7 @@ def threshold_mask(
569
526
  mask[mask < np.exp(-np.square(sigma))] = 0
570
527
 
571
528
  if invert:
572
- np.invert(mask, out=mask)
573
-
529
+ mask = 1 - mask
574
530
  return mask
575
531
 
576
532
 
@@ -757,15 +713,6 @@ class MaskWidget(widgets.Container):
757
713
  if self.method_dropdown.value == "Shape":
758
714
  new_layer.metadata = {}
759
715
 
760
- # origin_layer = metadata["origin_layer"]
761
- # if origin_layer in self.viewer.layers:
762
- # origin_layer = self.viewer.layers[origin_layer]
763
- # if np.allclose(origin_layer.data.shape, processed_data.shape):
764
- # in_mask = np.sum(np.fmax(origin_layer.data * processed_data, 0))
765
- # in_mask /= np.sum(np.fmax(origin_layer.data, 0))
766
- # in_mask *= 100
767
- # self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
768
-
769
716
 
770
717
  class AlignmentWidget(widgets.Container):
771
718
  def __init__(self, viewer):