pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
  4. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
  5. pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
  7. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
  8. scripts/estimate_ram_usage.py +97 -0
  9. scripts/extract_candidates.py +118 -99
  10. scripts/match_template.py +30 -41
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +35 -21
  13. scripts/preprocessor_gui.py +96 -24
  14. scripts/pytme_runner.py +644 -190
  15. scripts/refine_matches.py +158 -390
  16. tests/data/.DS_Store +0 -0
  17. tests/data/Blurring/.DS_Store +0 -0
  18. tests/data/Maps/.DS_Store +0 -0
  19. tests/data/Raw/.DS_Store +0 -0
  20. tests/data/Structures/.DS_Store +0 -0
  21. tests/preprocessing/test_utils.py +18 -0
  22. tests/test_analyzer.py +2 -3
  23. tests/test_backends.py +3 -9
  24. tests/test_density.py +0 -1
  25. tests/test_extensions.py +0 -1
  26. tests/test_matching_utils.py +10 -60
  27. tests/test_orientations.py +0 -12
  28. tests/test_rotations.py +1 -1
  29. tme/__version__.py +1 -1
  30. tme/analyzer/_utils.py +4 -4
  31. tme/analyzer/aggregation.py +35 -15
  32. tme/analyzer/peaks.py +11 -10
  33. tme/backends/_jax_utils.py +64 -18
  34. tme/backends/_numpyfftw_utils.py +270 -0
  35. tme/backends/cupy_backend.py +16 -55
  36. tme/backends/jax_backend.py +79 -40
  37. tme/backends/matching_backend.py +17 -51
  38. tme/backends/mlx_backend.py +1 -27
  39. tme/backends/npfftw_backend.py +71 -65
  40. tme/backends/pytorch_backend.py +1 -26
  41. tme/density.py +58 -5
  42. tme/extensions.cpython-311-darwin.so +0 -0
  43. tme/filters/ctf.py +22 -21
  44. tme/filters/wedge.py +10 -7
  45. tme/mask.py +341 -0
  46. tme/matching_data.py +31 -19
  47. tme/matching_exhaustive.py +37 -47
  48. tme/matching_optimization.py +2 -1
  49. tme/matching_scores.py +229 -411
  50. tme/matching_utils.py +73 -422
  51. tme/memory.py +1 -1
  52. tme/orientations.py +24 -13
  53. tme/rotations.py +1 -1
  54. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  55. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
  56. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
  57. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
  58. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
  59. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
  60. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
tme/matching_utils.py CHANGED
@@ -12,40 +12,36 @@ from shutil import move
12
12
  from joblib import Parallel
13
13
  from tempfile import mkstemp
14
14
  from itertools import product
15
+ from gzip import open as gzip_open
16
+ from typing import Tuple, Dict, Callable
15
17
  from concurrent.futures import ThreadPoolExecutor
16
- from typing import Tuple, Dict, Callable, Optional
17
18
 
18
19
  import numpy as np
19
20
  from scipy.spatial import ConvexHull
20
- from scipy.ndimage import gaussian_filter
21
21
 
22
22
  from .backends import backend as be
23
23
  from .memory import estimate_memory_usage
24
24
  from .types import NDArray, BackendArray
25
25
 
26
26
 
27
- def noop(*args, **kwargs):
28
- pass
29
-
30
-
31
- def identity(arr, *args):
27
+ def identity(arr, *args, **kwargs):
32
28
  return arr
33
29
 
34
30
 
35
31
  def conditional_execute(
36
32
  func: Callable,
37
- execute_operation: bool,
38
- alt_func: Callable = noop,
33
+ execute_operation: bool = False,
34
+ alt_func: Callable = identity,
39
35
  ) -> Callable:
40
36
  """
41
- Return the given function or a no-op function based on execute_operation.
37
+ Return the given function or alternative function based on execute_operation.
42
38
 
43
39
  Parameters
44
40
  ----------
45
41
  func : Callable
46
42
  Callable.
47
43
  alt_func : Callable
48
- Callable to return if ``execute_operation`` is False, no-op by default.
44
+ Callable to return if ``execute_operation`` is False, identity by default.
49
45
  execute_operation : bool
50
46
  Whether to return ``func`` or a ``alt_func`` function.
51
47
 
@@ -175,6 +171,12 @@ def memmap_to_array(arr: NDArray) -> NDArray:
175
171
  return arr
176
172
 
177
173
 
174
+ def is_gzipped(filename: str) -> bool:
175
+ """Check if a file is a gzip file by reading its magic number."""
176
+ with open(filename, "rb") as f:
177
+ return f.read(2) == b"\x1f\x8b"
178
+
179
+
178
180
  def write_pickle(data: object, filename: str) -> None:
179
181
  """
180
182
  Serialize and write data to a file invalidating the input data.
@@ -242,7 +244,11 @@ def load_pickle(filename: str) -> object:
242
244
  return ret
243
245
 
244
246
  items = []
245
- with open(filename, "rb") as ifile:
247
+ func = open
248
+ if is_gzipped(filename):
249
+ func = gzip_open
250
+
251
+ with func(filename, "rb") as ifile:
246
252
  for data in _load_pickle(ifile):
247
253
  if isinstance(data, tuple):
248
254
  if _is_pickle_memmap(data):
@@ -409,7 +415,7 @@ def compute_parallelization_schedule(
409
415
  return splits, core_assignment
410
416
 
411
417
 
412
- def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
418
+ def center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
413
419
  """Extract the center slice of ``current_shape`` to retrieve ``new_shape``."""
414
420
  new_shape = tuple(int(x) for x in new_shape)
415
421
  current_shape = tuple(int(x) for x in current_shape)
@@ -419,60 +425,12 @@ def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[sli
419
425
  return box
420
426
 
421
427
 
422
- def centered(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
423
- """
424
- Extract the centered portion of an array based on a new shape.
425
-
426
- Parameters
427
- ----------
428
- arr : BackendArray
429
- Input data.
430
- new_shape : tuple of ints
431
- Desired shape for the central portion.
432
-
433
- Returns
434
- -------
435
- BackendArray
436
- Central portion of the array with shape ``new_shape``.
437
-
438
- References
439
- ----------
440
- .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
441
- """
442
- box = _center_slice(arr.shape, new_shape=new_shape)
443
- return arr[box]
444
-
445
-
446
- def centered_mask(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
447
- """
448
- Mask the centered portion of an array based on a new shape.
449
-
450
- Parameters
451
- ----------
452
- arr : BackendArray
453
- Input data.
454
- new_shape : tuple of ints
455
- Desired shape for the mask.
456
-
457
- Returns
458
- -------
459
- BackendArray
460
- Array with central portion unmasked and the rest set to 0.
461
- """
462
- box = _center_slice(arr.shape, new_shape=new_shape)
463
- mask = np.zeros_like(arr)
464
- mask[box] = 1
465
- arr *= mask
466
- return arr
467
-
468
-
469
428
  def apply_convolution_mode(
470
429
  arr: BackendArray,
471
430
  convolution_mode: str,
472
431
  s1: Tuple[int],
473
432
  s2: Tuple[int],
474
433
  convolution_shape: Tuple[int] = None,
475
- mask_output: bool = False,
476
434
  ) -> BackendArray:
477
435
  """
478
436
  Applies convolution_mode to ``arr``.
@@ -497,9 +455,6 @@ def apply_convolution_mode(
497
455
  Tuple of integers corresponding to shape of convolution array 2.
498
456
  convolution_shape : tuple of ints, optional
499
457
  Size of the actually computed convolution. s1 + s2 - 1 by default.
500
- mask_output : bool, optional
501
- Whether to mask values outside of convolution_mode rather than
502
- removing them. Defaults to False.
503
458
 
504
459
  Returns
505
460
  -------
@@ -514,14 +469,13 @@ def apply_convolution_mode(
514
469
  if convolution_mode not in ("full", "same", "valid"):
515
470
  raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
516
471
 
517
- func = centered_mask if mask_output else centered
518
472
  if convolution_mode == "full":
519
- return arr
473
+ subset = ...
520
474
  elif convolution_mode == "same":
521
- return func(arr, s1)
475
+ subset = center_slice(arr.shape, s1)
522
476
  elif convolution_mode == "valid":
523
- valid_shape = [s1[i] - s2[i] + 1 for i in range(arr.ndim)]
524
- return func(arr, valid_shape)
477
+ subset = center_slice(arr.shape, [x - y + 1 for x, y in zip(s1, s2)])
478
+ return arr[subset]
525
479
 
526
480
 
527
481
  def compute_full_convolution_index(
@@ -716,350 +670,8 @@ def minimum_enclosing_box(
716
670
  return shape
717
671
 
718
672
 
719
- def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
720
- """
721
- Creates a mask of the specified type.
722
-
723
- Parameters
724
- ----------
725
- mask_type : str
726
- Type of the mask to be created. Can be one of:
727
-
728
- +---------+----------------------------------------------------------+
729
- | box | Box mask (see :py:meth:`box_mask`) |
730
- +---------+----------------------------------------------------------+
731
- | tube | Cylindrical mask (see :py:meth:`tube_mask`) |
732
- +---------+----------------------------------------------------------+
733
- | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`) |
734
- +---------+----------------------------------------------------------+
735
- sigma_decay : float, optional
736
- Smoothing along mask edges using a Gaussian filter, 0 by default.
737
- kwargs : dict
738
- Parameters passed to the indivdual mask creation funcitons.
739
-
740
- Returns
741
- -------
742
- NDArray
743
- The created mask.
744
-
745
- Raises
746
- ------
747
- ValueError
748
- If the mask_type is invalid.
749
- """
750
- mapping = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}
751
- if mask_type not in mapping:
752
- raise ValueError(f"mask_type has to be one of {','.join(mapping.keys())}")
753
-
754
- mask = mapping[mask_type](**kwargs)
755
- if sigma_decay > 0:
756
- mask_filter = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
757
- mask = np.add(mask, (1 - mask) * mask_filter)
758
- mask[mask < np.exp(-np.square(sigma_decay))] = 0
759
-
760
- return mask
761
-
762
-
763
- def elliptical_mask(
764
- shape: Tuple[int],
765
- radius: Tuple[float],
766
- center: Optional[Tuple[float]] = None,
767
- orientation: Optional[NDArray] = None,
768
- ) -> NDArray:
769
- """
770
- Creates an ellipsoidal mask.
771
-
772
- Parameters
773
- ----------
774
- shape : tuple of ints
775
- Shape of the mask to be created.
776
- radius : tuple of floats
777
- Radius of the mask.
778
- center : tuple of floats, optional
779
- Center of the mask, default to shape // 2.
780
- orientation : NDArray, optional.
781
- Orientation of the mask as rotation matrix with shape (d,d).
782
-
783
- Returns
784
- -------
785
- NDArray
786
- The created ellipsoidal mask.
787
-
788
- Raises
789
- ------
790
- ValueError
791
- If the length of center and radius is not one or the same as shape.
792
-
793
- Examples
794
- --------
795
- >>> from tme.matching_utils import elliptical_mask
796
- >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
797
- """
798
- shape, radius = np.asarray(shape), np.asarray(radius)
799
-
800
- shape = shape.astype(int)
801
- if center is None:
802
- center = np.divide(shape, 2).astype(int)
803
-
804
- center = np.asarray(center, dtype=np.float32)
805
- radius = np.repeat(radius, shape.size // radius.size)
806
- center = np.repeat(center, shape.size // center.size)
807
- if radius.size != shape.size:
808
- raise ValueError("Length of radius has to be either one or match shape.")
809
- if center.size != shape.size:
810
- raise ValueError("Length of center has to be either one or match shape.")
811
-
812
- n = shape.size
813
- center = center.reshape((-1,) + (1,) * n)
814
- radius = radius.reshape((-1,) + (1,) * n)
815
-
816
- indices = np.indices(shape, dtype=np.float32) - center
817
- if orientation is not None:
818
- return_shape = indices.shape
819
- indices = indices.reshape(n, -1)
820
- rigid_transform(
821
- coordinates=indices,
822
- rotation_matrix=np.asarray(orientation),
823
- out=indices,
824
- translation=np.zeros(n),
825
- use_geometric_center=False,
826
- )
827
- indices = indices.reshape(*return_shape)
828
-
829
- mask = np.linalg.norm(indices / radius, axis=0)
830
- mask = (mask <= 1).astype(int)
831
-
832
- return mask
833
-
834
-
835
- def tube_mask2(
836
- shape: Tuple[int],
837
- inner_radius: float,
838
- outer_radius: float,
839
- height: int,
840
- symmetry_axis: Optional[int] = 2,
841
- center: Optional[Tuple[float]] = None,
842
- orientation: Optional[NDArray] = None,
843
- epsilon: float = 0.5,
844
- ) -> NDArray:
845
- """
846
- Creates a tube mask.
847
-
848
- Parameters
849
- ----------
850
- shape : tuple
851
- Shape of the mask to be created.
852
- inner_radius : float
853
- Inner radius of the tube.
854
- outer_radius : float
855
- Outer radius of the tube.
856
- height : int
857
- Height of the tube.
858
- symmetry_axis : int, optional
859
- The axis of symmetry for the tube, defaults to 2.
860
- center : tuple of float, optional.
861
- Center of the mask, defaults to shape // 2.
862
- orientation : NDArray, optional.
863
- Orientation of the mask as rotation matrix with shape (d,d).
864
- epsilon : float, optional
865
- Tolerance to handle discretization errors, defaults to 0.5.
866
-
867
- Returns
868
- -------
869
- NDArray
870
- The created tube mask.
871
-
872
- Raises
873
- ------
874
- ValueError
875
- If ``inner_radius`` is larger than ``outer_radius``.
876
- If ``center`` and ``shape`` do not have the same length.
877
- """
878
- shape = np.asarray(shape, dtype=int)
879
-
880
- if center is None:
881
- center = np.divide(shape, 2).astype(int)
882
-
883
- center = np.asarray(center, dtype=np.float32)
884
- center = np.repeat(center, shape.size // center.size)
885
- if inner_radius > outer_radius:
886
- raise ValueError("inner_radius should be smaller than outer_radius.")
887
- if symmetry_axis > len(shape):
888
- raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
889
- if center.size != shape.size:
890
- raise ValueError("Length of center has to be either one or match shape.")
891
-
892
- n = shape.size
893
- center = center.reshape((-1,) + (1,) * n)
894
- indices = np.indices(shape, dtype=np.float32) - center
895
- if orientation is not None:
896
- return_shape = indices.shape
897
- indices = indices.reshape(n, -1)
898
- rigid_transform(
899
- coordinates=indices,
900
- rotation_matrix=np.asarray(orientation),
901
- out=indices,
902
- translation=np.zeros(n),
903
- use_geometric_center=False,
904
- )
905
- indices = indices.reshape(*return_shape)
906
-
907
- mask = np.zeros(shape, dtype=bool)
908
- sq_dist = np.zeros(shape)
909
- for i in range(len(shape)):
910
- if i == symmetry_axis:
911
- continue
912
- sq_dist += indices[i] ** 2
913
-
914
- sym_coord = indices[symmetry_axis]
915
- half_height = height / 2
916
- height_mask = np.abs(sym_coord) <= half_height
917
-
918
- inner_mask = 1
919
- if inner_radius > epsilon:
920
- inner_mask = sq_dist >= ((inner_radius) ** 2 - epsilon)
921
-
922
- height_mask = np.abs(sym_coord) <= (half_height + epsilon)
923
- outer_mask = sq_dist <= ((outer_radius) ** 2 + epsilon)
924
-
925
- mask = height_mask & inner_mask & outer_mask
926
- return mask
927
-
928
-
929
- def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.ndarray:
930
- """
931
- Creates a box mask centered around the provided center point.
932
-
933
- Parameters
934
- ----------
935
- shape : tuple of ints
936
- Shape of the output array.
937
- center : tuple of ints
938
- Center point coordinates of the box.
939
- height : tuple of ints
940
- Height (side length) of the box along each axis.
941
-
942
- Returns
943
- -------
944
- NDArray
945
- The created box mask.
946
-
947
- Raises
948
- ------
949
- ValueError
950
- If ``shape`` and ``center`` do not have the same length.
951
- If ``center`` and ``height`` do not have the same length.
952
- """
953
- if len(shape) != len(center) or len(center) != len(height):
954
- raise ValueError("The length of shape, center, and height must be consistent.")
955
-
956
- shape = tuple(int(x) for x in shape)
957
- center, height = np.array(center, dtype=int), np.array(height, dtype=int)
958
-
959
- half_heights = height // 2
960
- starts = np.maximum(center - half_heights, 0)
961
- stops = np.minimum(center + half_heights + np.mod(height, 2) + 1, shape)
962
- slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))
963
-
964
- out = np.zeros(shape)
965
- out[slice_indices] = 1
966
- return out
967
-
968
-
969
- def tube_mask(
970
- shape: Tuple[int],
971
- symmetry_axis: int,
972
- base_center: Tuple[int],
973
- inner_radius: float,
974
- outer_radius: float,
975
- height: int,
976
- ) -> NDArray:
977
- """
978
- Creates a tube mask.
979
-
980
- Parameters
981
- ----------
982
- shape : tuple
983
- Shape of the mask to be created.
984
- symmetry_axis : int
985
- The axis of symmetry for the tube.
986
- base_center : tuple
987
- Center of the tube.
988
- inner_radius : float
989
- Inner radius of the tube.
990
- outer_radius : float
991
- Outer radius of the tube.
992
- height : int
993
- Height of the tube.
994
-
995
- Returns
996
- -------
997
- NDArray
998
- The created tube mask.
999
-
1000
- Raises
1001
- ------
1002
- ValueError
1003
- If ``inner_radius`` is larger than ``outer_radius``.
1004
- If ``height`` is larger than the symmetry axis.
1005
- If ``base_center`` and ``shape`` do not have the same length.
1006
- """
1007
- if inner_radius > outer_radius:
1008
- raise ValueError("inner_radius should be smaller than outer_radius.")
1009
-
1010
- if height > shape[symmetry_axis]:
1011
- raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")
1012
-
1013
- if symmetry_axis > len(shape):
1014
- raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
1015
-
1016
- if len(base_center) != len(shape):
1017
- raise ValueError("shape and base_center need to have the same length.")
1018
-
1019
- shape = tuple(int(x) for x in shape)
1020
- circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
1021
- circle_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)
1022
-
1023
- inner_circle = np.zeros(circle_shape)
1024
- outer_circle = np.zeros_like(inner_circle)
1025
- if inner_radius > 0:
1026
- inner_circle = create_mask(
1027
- mask_type="ellipse",
1028
- shape=circle_shape,
1029
- radius=inner_radius,
1030
- center=circle_center,
1031
- )
1032
- if outer_radius > 0:
1033
- outer_circle = create_mask(
1034
- mask_type="ellipse",
1035
- shape=circle_shape,
1036
- radius=outer_radius,
1037
- center=circle_center,
1038
- )
1039
- circle = outer_circle - inner_circle
1040
- circle = np.expand_dims(circle, axis=symmetry_axis)
1041
-
1042
- center = base_center[symmetry_axis]
1043
- start_idx = int(center - height // 2)
1044
- stop_idx = int(center + height // 2 + height % 2)
1045
-
1046
- start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])
1047
-
1048
- slice_indices = tuple(
1049
- slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
1050
- for i in range(len(shape))
1051
- )
1052
- tube = np.zeros(shape)
1053
- tube[slice_indices] = circle
1054
-
1055
- return tube
1056
-
1057
-
1058
673
  def scramble_phases(
1059
- arr: NDArray,
1060
- noise_proportion: float = 1.0,
1061
- seed: int = 42,
1062
- normalize_power: bool = False,
674
+ arr: NDArray, noise_proportion: float = 1.0, seed: int = 42, **kwargs
1063
675
  ) -> NDArray:
1064
676
  """
1065
677
  Perform random phase scrambling of ``arr``.
@@ -1095,16 +707,7 @@ def scramble_phases(
1095
707
 
1096
708
  ph_noise = np.random.permutation(ph[mask])
1097
709
  ph[mask] = ph[mask] * (1 - noise_proportion) + ph_noise * noise_proportion
1098
- ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph)))
1099
-
1100
- if normalize_power:
1101
- np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
1102
- np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
1103
- np.add(ret, arr.min(), out=ret)
1104
- scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
1105
- np.multiply(ret, scaling, out=ret)
1106
-
1107
- return ret
710
+ return np.real(np.fft.ifftn(amp * np.exp(1j * ph)))
1108
711
 
1109
712
 
1110
713
  def compute_extraction_box(
@@ -1160,6 +763,54 @@ def compute_extraction_box(
1160
763
  return obs_beg_clamp, obs_end_clamp, cand_beg, cand_end, keep
1161
764
 
1162
765
 
766
+ def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
767
+ """
768
+ Creates a mask of the specified type.
769
+
770
+ Parameters
771
+ ----------
772
+ mask_type : str
773
+ Type of the mask to be created. Can be one of:
774
+
775
+ +----------+---------------------------------------------------------+
776
+ | box | Box mask (see :py:meth:`box_mask`) |
777
+ +----------+---------------------------------------------------------+
778
+ | tube | Cylindrical mask (see :py:meth:`tube_mask`) |
779
+ +----------+---------------------------------------------------------+
780
+ | membrane | Cylindrical mask (see :py:meth:`membrane_mask`) |
781
+ +----------+---------------------------------------------------------+
782
+ | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`) |
783
+ +----------+---------------------------------------------------------+
784
+ sigma_decay : float, optional
785
+ Smoothing along mask edges using a Gaussian filter, 0 by default.
786
+ kwargs : dict
787
+ Parameters passed to the indivdual mask creation funcitons.
788
+
789
+ Returns
790
+ -------
791
+ NDArray
792
+ The created mask.
793
+
794
+ Raises
795
+ ------
796
+ ValueError
797
+ If the mask_type is invalid.
798
+ """
799
+ from .mask import elliptical_mask, box_mask, tube_mask, membrane_mask
800
+
801
+ mapping = {
802
+ "ellipse": elliptical_mask,
803
+ "box": box_mask,
804
+ "tube": tube_mask,
805
+ "membrane": membrane_mask,
806
+ }
807
+ if mask_type not in mapping:
808
+ raise ValueError(f"mask_type has to be one of {','.join(mapping.keys())}")
809
+
810
+ mask = mapping[mask_type](**kwargs, sigma_decay=sigma_decay)
811
+ return mask
812
+
813
+
1163
814
  class TqdmParallel(Parallel):
1164
815
  """
1165
816
  A minimal Parallel implementation using tqdm for progress reporting.
tme/memory.py CHANGED
@@ -244,7 +244,7 @@ MATCHING_MEMORY_REGISTRY = {
244
244
  "PeakCallerMaximumFilter": PeakCallerMaximumFilterMemoryUsage,
245
245
  "cupy": CupyBackendMemoryUsage,
246
246
  "pytorch": CupyBackendMemoryUsage,
247
- "batchFLCSpherical": FLCSphericalMaskMemoryUsage,
247
+ "batchFLCSphericalMask": FLCSphericalMaskMemoryUsage,
248
248
  "batchFLC": FLCMemoryUsage,
249
249
  }
250
250
 
tme/orientations.py CHANGED
@@ -82,7 +82,7 @@ class Orientations:
82
82
  self.translations = np.array(self.translations).astype(np.float32)
83
83
  self.rotations = np.array(self.rotations).astype(np.float32)
84
84
  self.scores = np.array(self.scores).astype(np.float32)
85
- self.details = np.array(self.details)
85
+ self.details = np.array(self.details).astype(np.float32)
86
86
  n_orientations = set(
87
87
  [
88
88
  self.translations.shape[0],
@@ -327,11 +327,18 @@ class Orientations:
327
327
  "_rlnAnglePsi",
328
328
  "_rlnClassNumber",
329
329
  ]
330
+
331
+ target_identifer = "_rlnMicrographName"
332
+ if version == "# version 50001":
333
+ header[3] = "_rlnCenteredCoordinateXAngst"
334
+ header[4] = "_rlnCenteredCoordinateYAngst"
335
+ header[5] = "_rlnCenteredCoordinateZAngst"
336
+ target_identifer = "_rlnTomoName"
337
+
330
338
  if source_path is not None:
331
- header.append("_rlnMicrographName")
339
+ header.append(target_identifer)
332
340
 
333
341
  header.append("_pytmeScore")
334
-
335
342
  header = "\n".join(header)
336
343
  with open(filename, mode="w", encoding="utf-8") as ofile:
337
344
  if version is not None:
@@ -471,10 +478,7 @@ class Orientations:
471
478
  rotation = np.zeros(translation.shape, dtype=np.float32)
472
479
 
473
480
  header_order = tuple(x for x in header if x in NAMES)
474
- header_order = zip(header_order, range(len(header_order)))
475
- sort_order = tuple(
476
- x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=False)
477
- )
481
+ sort_order = tuple(NAMES.index(x) for x in header_order)
478
482
  translation = translation[..., sort_order]
479
483
 
480
484
  header_order = tuple(
@@ -490,16 +494,22 @@ class Orientations:
490
494
 
491
495
  @classmethod
492
496
  def _from_star(
493
- cls, filename: str, delimiter: str = "\t"
497
+ cls, filename: str, delimiter: str = None
494
498
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
495
499
  parser = StarParser(filename, delimiter=delimiter)
496
500
 
497
- ret = parser.get("data_particles", None)
498
- if ret is None:
499
- ret = parser.get("data_", None)
501
+ keyword_order = ("data_particles", "particles", "data")
502
+ for keyword in keyword_order:
503
+ ret = parser.get(keyword, None)
504
+ if ret is None:
505
+ ret = parser.get(f"{keyword}_", None)
506
+ if ret is not None:
507
+ break
500
508
 
501
509
  if ret is None:
502
- raise ValueError(f"No data_particles section found in {filename}.")
510
+ raise ValueError(
511
+ f"Could not find either {keyword_order} section found in {filename}."
512
+ )
503
513
 
504
514
  translation = np.vstack(
505
515
  (ret["_rlnCoordinateX"], ret["_rlnCoordinateY"], ret["_rlnCoordinateZ"])
@@ -518,8 +528,9 @@ class Orientations:
518
528
 
519
529
  default = np.zeros(translation.shape[0])
520
530
 
531
+ details = ret.get("_rlnClassNumber", default)
521
532
  scores = ret.get("_pytmeScore", default)
522
- return translation, rotation, scores, default
533
+ return translation, rotation, scores, details
523
534
 
524
535
  @staticmethod
525
536
  def _from_tbl(
tme/rotations.py CHANGED
@@ -250,7 +250,7 @@ def get_rotation_matrices(
250
250
  """
251
251
  if dim == 3 and use_optimized_set:
252
252
  quaternions, *_ = _load_quaternions_by_angle(angular_sampling)
253
- ret = Rotation.from_quat(quaternions).as_matrix()
253
+ ret = Rotation.from_quat(quaternions, scalar_first=True).as_matrix()
254
254
  else:
255
255
  num_rotations = dim * (dim - 1) // 2
256
256
  k = int((360 / angular_sampling) ** num_rotations)