pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/matching_utils.py CHANGED
@@ -12,40 +12,36 @@ from shutil import move
12
12
  from joblib import Parallel
13
13
  from tempfile import mkstemp
14
14
  from itertools import product
15
+ from gzip import open as gzip_open
16
+ from typing import Tuple, Dict, Callable
15
17
  from concurrent.futures import ThreadPoolExecutor
16
- from typing import Tuple, Dict, Callable, Optional
17
18
 
18
19
  import numpy as np
19
20
  from scipy.spatial import ConvexHull
20
- from scipy.ndimage import gaussian_filter
21
21
 
22
22
  from .backends import backend as be
23
23
  from .memory import estimate_memory_usage
24
24
  from .types import NDArray, BackendArray
25
25
 
26
26
 
27
- def noop(*args, **kwargs):
28
- pass
29
-
30
-
31
- def identity(arr, *args):
27
+ def identity(arr, *args, **kwargs):
32
28
  return arr
33
29
 
34
30
 
35
31
  def conditional_execute(
36
32
  func: Callable,
37
- execute_operation: bool,
38
- alt_func: Callable = noop,
33
+ execute_operation: bool = False,
34
+ alt_func: Callable = identity,
39
35
  ) -> Callable:
40
36
  """
41
- Return the given function or a no-op function based on execute_operation.
37
+ Return the given function or alternative function based on execute_operation.
42
38
 
43
39
  Parameters
44
40
  ----------
45
41
  func : Callable
46
42
  Callable.
47
43
  alt_func : Callable
48
- Callable to return if ``execute_operation`` is False, no-op by default.
44
+ Callable to return if ``execute_operation`` is False, identity by default.
49
45
  execute_operation : bool
50
46
  Whether to return ``func`` or a ``alt_func`` function.
51
47
 
@@ -175,6 +171,12 @@ def memmap_to_array(arr: NDArray) -> NDArray:
175
171
  return arr
176
172
 
177
173
 
174
+ def is_gzipped(filename: str) -> bool:
175
+ """Check if a file is a gzip file by reading its magic number."""
176
+ with open(filename, "rb") as f:
177
+ return f.read(2) == b"\x1f\x8b"
178
+
179
+
178
180
  def write_pickle(data: object, filename: str) -> None:
179
181
  """
180
182
  Serialize and write data to a file invalidating the input data.
@@ -242,7 +244,11 @@ def load_pickle(filename: str) -> object:
242
244
  return ret
243
245
 
244
246
  items = []
245
- with open(filename, "rb") as ifile:
247
+ func = open
248
+ if is_gzipped(filename):
249
+ func = gzip_open
250
+
251
+ with func(filename, "rb") as ifile:
246
252
  for data in _load_pickle(ifile):
247
253
  if isinstance(data, tuple):
248
254
  if _is_pickle_memmap(data):
@@ -409,7 +415,7 @@ def compute_parallelization_schedule(
409
415
  return splits, core_assignment
410
416
 
411
417
 
412
- def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
418
+ def center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
413
419
  """Extract the center slice of ``current_shape`` to retrieve ``new_shape``."""
414
420
  new_shape = tuple(int(x) for x in new_shape)
415
421
  current_shape = tuple(int(x) for x in current_shape)
@@ -419,60 +425,12 @@ def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[sli
419
425
  return box
420
426
 
421
427
 
422
- def centered(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
423
- """
424
- Extract the centered portion of an array based on a new shape.
425
-
426
- Parameters
427
- ----------
428
- arr : BackendArray
429
- Input data.
430
- new_shape : tuple of ints
431
- Desired shape for the central portion.
432
-
433
- Returns
434
- -------
435
- BackendArray
436
- Central portion of the array with shape ``new_shape``.
437
-
438
- References
439
- ----------
440
- .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
441
- """
442
- box = _center_slice(arr.shape, new_shape=new_shape)
443
- return arr[box]
444
-
445
-
446
- def centered_mask(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
447
- """
448
- Mask the centered portion of an array based on a new shape.
449
-
450
- Parameters
451
- ----------
452
- arr : BackendArray
453
- Input data.
454
- new_shape : tuple of ints
455
- Desired shape for the mask.
456
-
457
- Returns
458
- -------
459
- BackendArray
460
- Array with central portion unmasked and the rest set to 0.
461
- """
462
- box = _center_slice(arr.shape, new_shape=new_shape)
463
- mask = np.zeros_like(arr)
464
- mask[box] = 1
465
- arr *= mask
466
- return arr
467
-
468
-
469
428
  def apply_convolution_mode(
470
429
  arr: BackendArray,
471
430
  convolution_mode: str,
472
431
  s1: Tuple[int],
473
432
  s2: Tuple[int],
474
433
  convolution_shape: Tuple[int] = None,
475
- mask_output: bool = False,
476
434
  ) -> BackendArray:
477
435
  """
478
436
  Applies convolution_mode to ``arr``.
@@ -497,9 +455,6 @@ def apply_convolution_mode(
497
455
  Tuple of integers corresponding to shape of convolution array 2.
498
456
  convolution_shape : tuple of ints, optional
499
457
  Size of the actually computed convolution. s1 + s2 - 1 by default.
500
- mask_output : bool, optional
501
- Whether to mask values outside of convolution_mode rather than
502
- removing them. Defaults to False.
503
458
 
504
459
  Returns
505
460
  -------
@@ -514,14 +469,13 @@ def apply_convolution_mode(
514
469
  if convolution_mode not in ("full", "same", "valid"):
515
470
  raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
516
471
 
517
- func = centered_mask if mask_output else centered
518
472
  if convolution_mode == "full":
519
- return arr
473
+ subset = ...
520
474
  elif convolution_mode == "same":
521
- return func(arr, s1)
475
+ subset = center_slice(arr.shape, s1)
522
476
  elif convolution_mode == "valid":
523
- valid_shape = [s1[i] - s2[i] + s2[i] % 2 for i in range(arr.ndim)]
524
- return func(arr, valid_shape)
477
+ subset = center_slice(arr.shape, [x - y + 1 for x, y in zip(s1, s2)])
478
+ return arr[subset]
525
479
 
526
480
 
527
481
  def compute_full_convolution_index(
@@ -716,350 +670,8 @@ def minimum_enclosing_box(
716
670
  return shape
717
671
 
718
672
 
719
- def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
720
- """
721
- Creates a mask of the specified type.
722
-
723
- Parameters
724
- ----------
725
- mask_type : str
726
- Type of the mask to be created. Can be one of:
727
-
728
- +---------+----------------------------------------------------------+
729
- | box | Box mask (see :py:meth:`box_mask`) |
730
- +---------+----------------------------------------------------------+
731
- | tube | Cylindrical mask (see :py:meth:`tube_mask`) |
732
- +---------+----------------------------------------------------------+
733
- | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`) |
734
- +---------+----------------------------------------------------------+
735
- sigma_decay : float, optional
736
- Smoothing along mask edges using a Gaussian filter, 0 by default.
737
- kwargs : dict
738
- Parameters passed to the indivdual mask creation funcitons.
739
-
740
- Returns
741
- -------
742
- NDArray
743
- The created mask.
744
-
745
- Raises
746
- ------
747
- ValueError
748
- If the mask_type is invalid.
749
- """
750
- mapping = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}
751
- if mask_type not in mapping:
752
- raise ValueError(f"mask_type has to be one of {','.join(mapping.keys())}")
753
-
754
- mask = mapping[mask_type](**kwargs)
755
- if sigma_decay > 0:
756
- mask_filter = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
757
- mask = np.add(mask, (1 - mask) * mask_filter)
758
- mask[mask < np.exp(-np.square(sigma_decay))] = 0
759
-
760
- return mask
761
-
762
-
763
- def elliptical_mask(
764
- shape: Tuple[int],
765
- radius: Tuple[float],
766
- center: Optional[Tuple[float]] = None,
767
- orientation: Optional[NDArray] = None,
768
- ) -> NDArray:
769
- """
770
- Creates an ellipsoidal mask.
771
-
772
- Parameters
773
- ----------
774
- shape : tuple of ints
775
- Shape of the mask to be created.
776
- radius : tuple of floats
777
- Radius of the mask.
778
- center : tuple of floats, optional
779
- Center of the mask, default to shape // 2.
780
- orientation : NDArray, optional.
781
- Orientation of the mask as rotation matrix with shape (d,d).
782
-
783
- Returns
784
- -------
785
- NDArray
786
- The created ellipsoidal mask.
787
-
788
- Raises
789
- ------
790
- ValueError
791
- If the length of center and radius is not one or the same as shape.
792
-
793
- Examples
794
- --------
795
- >>> from tme.matching_utils import elliptical_mask
796
- >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
797
- """
798
- shape, radius = np.asarray(shape), np.asarray(radius)
799
-
800
- shape = shape.astype(int)
801
- if center is None:
802
- center = np.divide(shape, 2).astype(int)
803
-
804
- center = np.asarray(center, dtype=np.float32)
805
- radius = np.repeat(radius, shape.size // radius.size)
806
- center = np.repeat(center, shape.size // center.size)
807
- if radius.size != shape.size:
808
- raise ValueError("Length of radius has to be either one or match shape.")
809
- if center.size != shape.size:
810
- raise ValueError("Length of center has to be either one or match shape.")
811
-
812
- n = shape.size
813
- center = center.reshape((-1,) + (1,) * n)
814
- radius = radius.reshape((-1,) + (1,) * n)
815
-
816
- indices = np.indices(shape, dtype=np.float32) - center
817
- if orientation is not None:
818
- return_shape = indices.shape
819
- indices = indices.reshape(n, -1)
820
- rigid_transform(
821
- coordinates=indices,
822
- rotation_matrix=np.asarray(orientation),
823
- out=indices,
824
- translation=np.zeros(n),
825
- use_geometric_center=False,
826
- )
827
- indices = indices.reshape(*return_shape)
828
-
829
- mask = np.linalg.norm(indices / radius, axis=0)
830
- mask = (mask <= 1).astype(int)
831
-
832
- return mask
833
-
834
-
835
- def tube_mask2(
836
- shape: Tuple[int],
837
- inner_radius: float,
838
- outer_radius: float,
839
- height: int,
840
- symmetry_axis: Optional[int] = 2,
841
- center: Optional[Tuple[float]] = None,
842
- orientation: Optional[NDArray] = None,
843
- epsilon: float = 0.5,
844
- ) -> NDArray:
845
- """
846
- Creates a tube mask.
847
-
848
- Parameters
849
- ----------
850
- shape : tuple
851
- Shape of the mask to be created.
852
- inner_radius : float
853
- Inner radius of the tube.
854
- outer_radius : float
855
- Outer radius of the tube.
856
- height : int
857
- Height of the tube.
858
- symmetry_axis : int, optional
859
- The axis of symmetry for the tube, defaults to 2.
860
- center : tuple of float, optional.
861
- Center of the mask, defaults to shape // 2.
862
- orientation : NDArray, optional.
863
- Orientation of the mask as rotation matrix with shape (d,d).
864
- epsilon : float, optional
865
- Tolerance to handle discretization errors, defaults to 0.5.
866
-
867
- Returns
868
- -------
869
- NDArray
870
- The created tube mask.
871
-
872
- Raises
873
- ------
874
- ValueError
875
- If ``inner_radius`` is larger than ``outer_radius``.
876
- If ``center`` and ``shape`` do not have the same length.
877
- """
878
- shape = np.asarray(shape, dtype=int)
879
-
880
- if center is None:
881
- center = np.divide(shape, 2).astype(int)
882
-
883
- center = np.asarray(center, dtype=np.float32)
884
- center = np.repeat(center, shape.size // center.size)
885
- if inner_radius > outer_radius:
886
- raise ValueError("inner_radius should be smaller than outer_radius.")
887
- if symmetry_axis > len(shape):
888
- raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
889
- if center.size != shape.size:
890
- raise ValueError("Length of center has to be either one or match shape.")
891
-
892
- n = shape.size
893
- center = center.reshape((-1,) + (1,) * n)
894
- indices = np.indices(shape, dtype=np.float32) - center
895
- if orientation is not None:
896
- return_shape = indices.shape
897
- indices = indices.reshape(n, -1)
898
- rigid_transform(
899
- coordinates=indices,
900
- rotation_matrix=np.asarray(orientation),
901
- out=indices,
902
- translation=np.zeros(n),
903
- use_geometric_center=False,
904
- )
905
- indices = indices.reshape(*return_shape)
906
-
907
- mask = np.zeros(shape, dtype=bool)
908
- sq_dist = np.zeros(shape)
909
- for i in range(len(shape)):
910
- if i == symmetry_axis:
911
- continue
912
- sq_dist += indices[i] ** 2
913
-
914
- sym_coord = indices[symmetry_axis]
915
- half_height = height / 2
916
- height_mask = np.abs(sym_coord) <= half_height
917
-
918
- inner_mask = 1
919
- if inner_radius > epsilon:
920
- inner_mask = sq_dist >= ((inner_radius) ** 2 - epsilon)
921
-
922
- height_mask = np.abs(sym_coord) <= (half_height + epsilon)
923
- outer_mask = sq_dist <= ((outer_radius) ** 2 + epsilon)
924
-
925
- mask = height_mask & inner_mask & outer_mask
926
- return mask
927
-
928
-
929
- def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.ndarray:
930
- """
931
- Creates a box mask centered around the provided center point.
932
-
933
- Parameters
934
- ----------
935
- shape : tuple of ints
936
- Shape of the output array.
937
- center : tuple of ints
938
- Center point coordinates of the box.
939
- height : tuple of ints
940
- Height (side length) of the box along each axis.
941
-
942
- Returns
943
- -------
944
- NDArray
945
- The created box mask.
946
-
947
- Raises
948
- ------
949
- ValueError
950
- If ``shape`` and ``center`` do not have the same length.
951
- If ``center`` and ``height`` do not have the same length.
952
- """
953
- if len(shape) != len(center) or len(center) != len(height):
954
- raise ValueError("The length of shape, center, and height must be consistent.")
955
-
956
- shape = tuple(int(x) for x in shape)
957
- center, height = np.array(center, dtype=int), np.array(height, dtype=int)
958
-
959
- half_heights = height // 2
960
- starts = np.maximum(center - half_heights, 0)
961
- stops = np.minimum(center + half_heights + np.mod(height, 2) + 1, shape)
962
- slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))
963
-
964
- out = np.zeros(shape)
965
- out[slice_indices] = 1
966
- return out
967
-
968
-
969
- def tube_mask(
970
- shape: Tuple[int],
971
- symmetry_axis: int,
972
- base_center: Tuple[int],
973
- inner_radius: float,
974
- outer_radius: float,
975
- height: int,
976
- ) -> NDArray:
977
- """
978
- Creates a tube mask.
979
-
980
- Parameters
981
- ----------
982
- shape : tuple
983
- Shape of the mask to be created.
984
- symmetry_axis : int
985
- The axis of symmetry for the tube.
986
- base_center : tuple
987
- Center of the tube.
988
- inner_radius : float
989
- Inner radius of the tube.
990
- outer_radius : float
991
- Outer radius of the tube.
992
- height : int
993
- Height of the tube.
994
-
995
- Returns
996
- -------
997
- NDArray
998
- The created tube mask.
999
-
1000
- Raises
1001
- ------
1002
- ValueError
1003
- If ``inner_radius`` is larger than ``outer_radius``.
1004
- If ``height`` is larger than the symmetry axis.
1005
- If ``base_center`` and ``shape`` do not have the same length.
1006
- """
1007
- if inner_radius > outer_radius:
1008
- raise ValueError("inner_radius should be smaller than outer_radius.")
1009
-
1010
- if height > shape[symmetry_axis]:
1011
- raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")
1012
-
1013
- if symmetry_axis > len(shape):
1014
- raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
1015
-
1016
- if len(base_center) != len(shape):
1017
- raise ValueError("shape and base_center need to have the same length.")
1018
-
1019
- shape = tuple(int(x) for x in shape)
1020
- circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
1021
- circle_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)
1022
-
1023
- inner_circle = np.zeros(circle_shape)
1024
- outer_circle = np.zeros_like(inner_circle)
1025
- if inner_radius > 0:
1026
- inner_circle = create_mask(
1027
- mask_type="ellipse",
1028
- shape=circle_shape,
1029
- radius=inner_radius,
1030
- center=circle_center,
1031
- )
1032
- if outer_radius > 0:
1033
- outer_circle = create_mask(
1034
- mask_type="ellipse",
1035
- shape=circle_shape,
1036
- radius=outer_radius,
1037
- center=circle_center,
1038
- )
1039
- circle = outer_circle - inner_circle
1040
- circle = np.expand_dims(circle, axis=symmetry_axis)
1041
-
1042
- center = base_center[symmetry_axis]
1043
- start_idx = int(center - height // 2)
1044
- stop_idx = int(center + height // 2 + height % 2)
1045
-
1046
- start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])
1047
-
1048
- slice_indices = tuple(
1049
- slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
1050
- for i in range(len(shape))
1051
- )
1052
- tube = np.zeros(shape)
1053
- tube[slice_indices] = circle
1054
-
1055
- return tube
1056
-
1057
-
1058
673
  def scramble_phases(
1059
- arr: NDArray,
1060
- noise_proportion: float = 0.5,
1061
- seed: int = 42,
1062
- normalize_power: bool = False,
674
+ arr: NDArray, noise_proportion: float = 1.0, seed: int = 42, **kwargs
1063
675
  ) -> NDArray:
1064
676
  """
1065
677
  Perform random phase scrambling of ``arr``.
@@ -1069,7 +681,7 @@ def scramble_phases(
1069
681
  arr : NDArray
1070
682
  Input data.
1071
683
  noise_proportion : float, optional
1072
- Proportion of scrambled phases, 0.5 by default.
684
+ Proportion of scrambled phases, 1.0 by default.
1073
685
  seed : int, optional
1074
686
  The seed for the random phase scrambling, 42 by default.
1075
687
  normalize_power : bool, optional
@@ -1080,24 +692,22 @@ def scramble_phases(
1080
692
  NDArray
1081
693
  Phase scrambled version of ``arr``.
1082
694
  """
695
+ from tme.filters._utils import fftfreqn
696
+
1083
697
  np.random.seed(seed)
1084
698
  noise_proportion = max(min(noise_proportion, 1), 0)
1085
699
 
1086
700
  arr_fft = np.fft.fftn(arr)
1087
701
  amp, ph = np.abs(arr_fft), np.angle(arr_fft)
1088
702
 
1089
- ph_noise = np.random.permutation(ph)
1090
- ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
1091
- ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
1092
-
1093
- if normalize_power:
1094
- np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
1095
- np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
1096
- np.add(ret, arr.min(), out=ret)
1097
- scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
1098
- np.multiply(ret, scaling, out=ret)
703
+ # Scrambling up to nyquist gives more uniform noise distribution
704
+ mask = np.fft.ifftshift(
705
+ fftfreqn(arr_fft.shape, sampling_rate=1, compute_euclidean_norm=True) <= 0.5
706
+ )
1099
707
 
1100
- return ret
708
+ ph_noise = np.random.permutation(ph[mask])
709
+ ph[mask] = ph[mask] * (1 - noise_proportion) + ph_noise * noise_proportion
710
+ return np.real(np.fft.ifftn(amp * np.exp(1j * ph)))
1101
711
 
1102
712
 
1103
713
  def compute_extraction_box(
@@ -1153,6 +763,54 @@ def compute_extraction_box(
1153
763
  return obs_beg_clamp, obs_end_clamp, cand_beg, cand_end, keep
1154
764
 
1155
765
 
766
+ def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
767
+ """
768
+ Creates a mask of the specified type.
769
+
770
+ Parameters
771
+ ----------
772
+ mask_type : str
773
+ Type of the mask to be created. Can be one of:
774
+
775
+ +----------+---------------------------------------------------------+
776
+ | box | Box mask (see :py:meth:`box_mask`) |
777
+ +----------+---------------------------------------------------------+
778
+ | tube | Cylindrical mask (see :py:meth:`tube_mask`) |
779
+ +----------+---------------------------------------------------------+
780
+ | membrane | Cylindrical mask (see :py:meth:`membrane_mask`) |
781
+ +----------+---------------------------------------------------------+
782
+ | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`) |
783
+ +----------+---------------------------------------------------------+
784
+ sigma_decay : float, optional
785
+ Smoothing along mask edges using a Gaussian filter, 0 by default.
786
+ kwargs : dict
787
+ Parameters passed to the indivdual mask creation funcitons.
788
+
789
+ Returns
790
+ -------
791
+ NDArray
792
+ The created mask.
793
+
794
+ Raises
795
+ ------
796
+ ValueError
797
+ If the mask_type is invalid.
798
+ """
799
+ from .mask import elliptical_mask, box_mask, tube_mask, membrane_mask
800
+
801
+ mapping = {
802
+ "ellipse": elliptical_mask,
803
+ "box": box_mask,
804
+ "tube": tube_mask,
805
+ "membrane": membrane_mask,
806
+ }
807
+ if mask_type not in mapping:
808
+ raise ValueError(f"mask_type has to be one of {','.join(mapping.keys())}")
809
+
810
+ mask = mapping[mask_type](**kwargs, sigma_decay=sigma_decay)
811
+ return mask
812
+
813
+
1156
814
  class TqdmParallel(Parallel):
1157
815
  """
1158
816
  A minimal Parallel implementation using tqdm for progress reporting.
tme/memory.py CHANGED
@@ -244,7 +244,7 @@ MATCHING_MEMORY_REGISTRY = {
244
244
  "PeakCallerMaximumFilter": PeakCallerMaximumFilterMemoryUsage,
245
245
  "cupy": CupyBackendMemoryUsage,
246
246
  "pytorch": CupyBackendMemoryUsage,
247
- "batchFLCSpherical": FLCSphericalMaskMemoryUsage,
247
+ "batchFLCSphericalMask": FLCSphericalMaskMemoryUsage,
248
248
  "batchFLC": FLCMemoryUsage,
249
249
  }
250
250
 
tme/orientations.py CHANGED
@@ -82,7 +82,7 @@ class Orientations:
82
82
  self.translations = np.array(self.translations).astype(np.float32)
83
83
  self.rotations = np.array(self.rotations).astype(np.float32)
84
84
  self.scores = np.array(self.scores).astype(np.float32)
85
- self.details = np.array(self.details)
85
+ self.details = np.array(self.details).astype(np.float32)
86
86
  n_orientations = set(
87
87
  [
88
88
  self.translations.shape[0],
@@ -471,10 +471,7 @@ class Orientations:
471
471
  rotation = np.zeros(translation.shape, dtype=np.float32)
472
472
 
473
473
  header_order = tuple(x for x in header if x in NAMES)
474
- header_order = zip(header_order, range(len(header_order)))
475
- sort_order = tuple(
476
- x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=False)
477
- )
474
+ sort_order = tuple(NAMES.index(x) for x in header_order)
478
475
  translation = translation[..., sort_order]
479
476
 
480
477
  header_order = tuple(
@@ -492,9 +489,12 @@ class Orientations:
492
489
  def _from_star(
493
490
  cls, filename: str, delimiter: str = "\t"
494
491
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
495
- ret = StarParser(filename, delimiter=delimiter)
492
+ parser = StarParser(filename, delimiter=delimiter)
493
+
494
+ ret = parser.get("data_particles", None)
495
+ if ret is None:
496
+ ret = parser.get("data_", None)
496
497
 
497
- ret = ret.get("data_particles", None)
498
498
  if ret is None:
499
499
  raise ValueError(f"No data_particles section found in {filename}.")
500
500
 
@@ -503,13 +503,21 @@ class Orientations:
503
503
  )
504
504
  translation = translation.astype(np.float32).T
505
505
 
506
+ default_angle = np.zeros(translation.shape[0], dtype=np.float32)
507
+ for x in ("_rlnAngleRot", "_rlnAngleTilt", "_rlnAnglePsi"):
508
+ if x not in ret:
509
+ ret[x] = default_angle
510
+
506
511
  rotation = np.vstack(
507
512
  (ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"])
508
513
  )
509
514
  rotation = rotation.astype(np.float32).T
510
515
 
511
516
  default = np.zeros(translation.shape[0])
512
- return translation, rotation, default, default
517
+
518
+ details = ret.get("_rlnClassNumber", default)
519
+ scores = ret.get("_pytmeScore", default)
520
+ return translation, rotation, scores, details
513
521
 
514
522
  @staticmethod
515
523
  def _from_tbl(