pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. pytme-0.3b0.post1.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3b0.post1.data/scripts/match_template.py +1098 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +318 -189
  4. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +12 -12
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +21 -20
  8. pytme-0.3b0.post1.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +341 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +318 -189
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +12 -12
  19. scripts/pytme_runner.py +769 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/filters/__init__.py +3 -3
  49. tme/filters/_utils.py +36 -10
  50. tme/filters/bandpass.py +229 -188
  51. tme/filters/compose.py +5 -4
  52. tme/filters/ctf.py +516 -254
  53. tme/filters/reconstruction.py +91 -32
  54. tme/filters/wedge.py +196 -135
  55. tme/filters/whitening.py +37 -42
  56. tme/matching_data.py +28 -39
  57. tme/matching_exhaustive.py +31 -27
  58. tme/matching_optimization.py +5 -4
  59. tme/matching_scores.py +25 -15
  60. tme/matching_utils.py +54 -9
  61. tme/memory.py +4 -3
  62. tme/orientations.py +22 -9
  63. tme/parser.py +114 -33
  64. tme/preprocessor.py +6 -5
  65. tme/rotations.py +10 -7
  66. tme/structure.py +4 -3
  67. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  68. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  69. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  70. scripts/estimate_ram_usage.py +0 -97
  71. tests/data/Maps/.DS_Store +0 -0
  72. tests/data/Structures/.DS_Store +0 -0
  73. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/cli.py ADDED
@@ -0,0 +1,126 @@
1
+ #!python3
2
+ """
3
+ CLI utility functions.
4
+
5
+ Copyright (c) 2025 European Molecular Biology Laboratory
6
+
7
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
+ """
9
+
10
+ import argparse
11
+
12
+ import numpy as np
13
+ from . import __version__
14
+ from .types import BackendArray
15
+
16
+
17
def match_template(
    target: BackendArray,
    template: BackendArray,
    template_mask: BackendArray = None,
    score="FLCSphericalMask",
    rotations=None,
):
    """
    Simple template matching run.

    Parameters
    ----------
    target : BackendArray
        Target array.
    template : BackendArray
        Template to be matched against target.
    template_mask : BackendArray, optional
        Template mask for normalization, defaults to None.
    score : str, optional
        Scoring method to use, defaults to 'FLCSphericalMask'. Used as the
        lookup key into ``MATCHING_EXHAUSTIVE_REGISTER``.
    rotations : BackendArray, optional
        Rotation matrices with shape (n, d, d), where d is the dimension
        of the target. Defaults to the identity rotation matrix.

    Returns
    -------
    list
        The output of ``scan_subsets`` collected into a list. With the
        ``MaxScoreOverRotations`` analyzer this is presumably, in order
        (TODO confirm against the analyzer implementation):

        scores : BackendArray
            Computed cross-correlation scores.
        offset : BackendArray
            Offset in target, defaults to 0.
        rotations : BackendArray
            Map between translations and rotation indices.
        rotation_mapping : dict
            Map between rotation indices and rotation matrices.

    Raises
    ------
    ValueError
        If the trailing two axes of ``rotations`` are not
        ``(target.ndim, target.ndim)``.
    """
    ndim = target.ndim
    if rotations is None:
        rotations = np.eye(ndim).reshape(1, ndim, ndim)

    # Fail early with a clear message: mismatched rotation matrices cannot be
    # applied to the template, so continuing would only defer the failure.
    if tuple(rotations.shape[-2:]) != (ndim, ndim):
        raise ValueError(
            f"Dimension of rotation matrix {rotations.shape[-1]} does not "
            "match target dimension."
        )

    # Project imports are kept local to the function, as in the original
    # implementation; they run after the cheap argument validation above.
    from .matching_data import MatchingData
    from .analyzer import MaxScoreOverRotations
    from .matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER

    matching_data = MatchingData(
        target=target,
        template=template,
        template_mask=template_mask,
        rotations=rotations,
    )
    # NOTE(review): the mask is already passed to the constructor; this
    # explicit assignment is preserved from the original code.
    matching_data.template_mask = template_mask
    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[score]

    candidates = list(
        scan_subsets(
            matching_data=matching_data,
            matching_score=matching_score,
            matching_setup=matching_setup,
            callback_class=MaxScoreOverRotations,
            callback_class_args={
                "score_threshold": -1,
            },
            pad_target_edges=True,
            job_schedule=(1, 1),
        )
    )
    return candidates
89
+
90
+
91
def sanitize_name(name: str) -> str:
    """
    Convert an identifier-like name into a title-cased display label.

    Parameters
    ----------
    name : str
        Name to convert, e.g. an option or parameter name.

    Returns
    -------
    str
        Title-cased name with underscores and hyphens replaced by spaces.
    """
    # Title-case first (separators already act as word boundaries), then
    # swap both separator characters for spaces in a single pass.
    return name.title().translate(str.maketrans("_-", "  "))
93
+
94
+
95
+ def print_entry() -> None:
96
+ width = 80
97
+ text = f" pytme v{__version__} "
98
+ padding_total = width - len(text) - 2
99
+ padding_left = padding_total // 2
100
+ padding_right = padding_total - padding_left
101
+
102
+ print("*" * width)
103
+ print(f"*{ ' ' * padding_left }{text}{ ' ' * padding_right }*")
104
+ print("*" * width)
105
+
106
+
107
def get_func_fullname(func) -> str:
    """
    Returns the full name of the given function, including its module.

    Parameters
    ----------
    func : callable
        Function to describe. ``functools.partial`` objects are unwrapped to
        the underlying callable; callables without a ``__name__`` attribute
        are described by the name of their type.

    Returns
    -------
    str
        String of the form ``<function 'module.name'>``.
    """
    import functools

    # partial objects carry no __name__; report the function they wrap.
    while isinstance(func, functools.partial):
        func = func.func

    module = getattr(func, "__module__", type(func).__module__)
    name = getattr(func, "__name__", type(func).__name__)
    return f"<function '{module}.{name}'>"
110
+
111
+
112
def print_block(name: str, data: dict, label_width=20) -> None:
    """
    Prints a formatted block of information.

    Parameters
    ----------
    name : str
        Heading printed above the entries.
    data : dict
        Entries to print, one line per key. Numpy arrays are shown by
        their shape instead of their contents.
    label_width : int, optional
        Column width the ``key:`` labels are left-aligned to, 20 by default.
    """
    print(f"\n> {name}")
    for label, entry in data.items():
        shown = entry.shape if isinstance(entry, np.ndarray) else entry
        print(f" - {str(label) + ':':<{label_width}} {str(shown)}")
120
+
121
+
122
def check_positive(value):
    """
    Argparse ``type`` callable that accepts strictly positive floats.

    Parameters
    ----------
    value : str
        Raw command line value.

    Returns
    -------
    float
        The parsed value.

    Raises
    ------
    argparse.ArgumentTypeError
        If the parsed value is not strictly greater than zero. This includes
        NaN, which compares false against every number.
    ValueError
        If ``value`` cannot be converted to float.
    """
    number = float(value)
    # 'not number > 0' instead of 'number <= 0' so that NaN is rejected too.
    if not number > 0:
        raise argparse.ArgumentTypeError(f"{value} is an invalid positive float.")
    return number
tme/density.py CHANGED
@@ -1,8 +1,9 @@
1
- """ Representation of N-dimensional densities
1
+ """
2
+ Representation of N-dimensional densities
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  import warnings
@@ -1762,12 +1763,13 @@ class Density:
1762
1763
  axis=axis,
1763
1764
  )
1764
1765
 
1765
- arr_ft = np.fft.fftn(self.data)
1766
+ mask, mask_ret = np.where(mask), np.where(mask_ret)
1767
+
1768
+ arr_ft = np.fft.fftn(self.data)[mask]
1766
1769
  arr_ft *= np.prod(ret_shape) / np.prod(self.shape)
1767
1770
  ret_ft = np.zeros(ret_shape, dtype=arr_ft.dtype)
1768
- ret_ft[mask_ret] = arr_ft[mask]
1769
- ret.data = np.real(np.fft.ifftn(ret_ft))
1770
-
1771
+ np.add.at(ret_ft, mask_ret, arr_ft)
1772
+ ret.data = np.real(np.fft.ifftn(ret_ft)).astype(self.data.dtype)
1771
1773
  ret.sampling_rate = new_sampling_rate
1772
1774
  return ret
1773
1775
 
tme/filters/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
- from .ctf import CTF
1
+ from .ctf import CTF, CTFReconstructed
2
2
  from .compose import Compose, ComposableFilter
3
- from .bandpass import BandPassFilter
3
+ from .bandpass import BandPass, BandPassReconstructed
4
4
  from .whitening import LinearWhiteningFilter
5
5
  from .wedge import Wedge, WedgeReconstructed
6
- from .reconstruction import ReconstructFromTilt
6
+ from .reconstruction import ReconstructFromTilt, ShiftFourier
tme/filters/_utils.py CHANGED
@@ -1,8 +1,9 @@
1
- """ Utilities for the generation of frequency grids.
1
+ """
2
+ Utilities for the generation of frequency grids.
2
3
 
3
- Copyright (c) 2024 European Molecular Biology Laboratory
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from typing import Tuple, List, Dict
@@ -14,6 +15,18 @@ from ..backends import NumpyFFTWBackend
14
15
  from ..types import BackendArray, NDArray
15
16
  from ..rotations import euler_to_rotationmatrix
16
17
 
18
+ __all__ = [
19
+ "compute_tilt_shape",
20
+ "centered_grid",
21
+ "frequency_grid_at_angle",
22
+ "fftfreqn",
23
+ "crop_real_fourier",
24
+ "compute_fourier_shape",
25
+ "shift_fourier",
26
+ "create_reconstruction_filter",
27
+ "pad_to_length",
28
+ ]
29
+
17
30
 
18
31
  def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
19
32
  """
@@ -70,21 +83,27 @@ def frequency_grid_at_angle(
70
83
  """
71
84
  Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.
72
85
 
86
+ Conceptually, this function generates an accurate frequency grid of tilted
87
+ projections. Given a non-cubical shape, it is no longer accurate to compute
88
+ frequencies as Euclidean distances from a centered index grid. This function
89
+ solves this issue, and makes it possible to create complex filters on
90
+ non-cubical input shapes.
91
+
73
92
  Parameters
74
93
  ----------
75
- shape : Tuple[int]
94
+ shape : tuple of int
76
95
  The shape of the grid.
77
96
  angle : float
78
97
  The angle at which to generate the grid.
79
- sampling_rate : Tuple[float]
98
+ sampling_rate : tuple of float
80
99
  The sampling rate for each dimension.
81
100
  opening_axis : int, optional
82
- The axis to be opened, defaults to None.
101
+ The projection axis, defaults to None.
83
102
  tilt_axis : int, optional
84
103
  The axis along which the grid is tilted, defaults to None.
85
104
 
86
- Returns:
87
- --------
105
+ Returns
106
+ -------
88
107
  NDArray
89
108
  The frequency grid.
90
109
  """
@@ -230,7 +249,9 @@ def shift_fourier(
230
249
  def create_reconstruction_filter(
231
250
  filter_shape: Tuple[int], filter_type: str, **kwargs: Dict
232
251
  ):
233
- """Create a reconstruction filter of given filter_type.
252
+ """
253
+ Create a reconstruction filter of given filter_type. The DC component of
254
+ the filter will be located in the array center.
234
255
 
235
256
  Parameters
236
257
  ----------
@@ -298,7 +319,7 @@ def create_reconstruction_filter(
298
319
  ret = fftfreqn((size,), sampling_rate=1, compute_euclidean_norm=True)
299
320
  min_increment = np.radians(np.min(np.abs(np.diff(np.sort(tilt_angles)))))
300
321
  ret *= min_increment * size
301
- np.fmin(ret, 1, out=ret)
322
+ ret = np.fmin(ret, 1, out=ret)
302
323
  elif filter_type == "shepp-logan":
303
324
  ret = freq * np.sinc(freq / 2)
304
325
  elif filter_type == "cosine":
@@ -309,3 +330,8 @@ def create_reconstruction_filter(
309
330
  raise ValueError("Unsupported filter type")
310
331
 
311
332
  return ret
333
+
334
+
335
def pad_to_length(arr, length: int):
    """
    Expand a scalar or array to ``length`` elements by repeating its entries.

    Parameters
    ----------
    arr : scalar or array-like
        Value(s) to expand. Each element is repeated ``length // size``
        times, so a scalar becomes a constant array of ``length`` elements
        and an input that already has ``length`` elements is returned
        unchanged in value.
    length : int
        Number of elements of the output.

    Returns
    -------
    NDArray
        One-dimensional array with ``length`` elements.

    Raises
    ------
    ValueError
        If ``arr`` is empty or ``length`` is not a multiple of its size, in
        which case no whole number of repeats yields ``length`` elements.
    """
    ret = np.atleast_1d(arr)
    # Without this check the floor division below silently produces an array
    # that is too short (or empty when the input is longer than length).
    if ret.size == 0 or length % ret.size != 0:
        raise ValueError(
            f"Cannot pad input of size {ret.size} to length {length}."
        )
    return np.repeat(ret, length // ret.size)