pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
Files changed (63)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
  2. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
  3. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
  4. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
  5. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
  6. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
  7. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
  8. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
  9. scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
  10. scripts/extract_candidates.py +224 -0
  11. scripts/match_template.py +224 -223
  12. scripts/postprocess.py +283 -163
  13. scripts/preprocess.py +11 -8
  14. scripts/preprocessor_gui.py +10 -9
  15. scripts/refine_matches.py +626 -0
  16. tests/preprocessing/test_frequency_filters.py +9 -4
  17. tests/test_analyzer.py +143 -138
  18. tests/test_matching_cli.py +85 -29
  19. tests/test_matching_exhaustive.py +1 -2
  20. tests/test_matching_optimization.py +4 -9
  21. tests/test_orientations.py +0 -1
  22. tme/__version__.py +1 -1
  23. tme/analyzer/__init__.py +2 -0
  24. tme/analyzer/_utils.py +25 -17
  25. tme/analyzer/aggregation.py +385 -220
  26. tme/analyzer/base.py +138 -0
  27. tme/analyzer/peaks.py +150 -88
  28. tme/analyzer/proxy.py +122 -0
  29. tme/backends/__init__.py +4 -3
  30. tme/backends/_cupy_utils.py +25 -24
  31. tme/backends/_jax_utils.py +4 -3
  32. tme/backends/cupy_backend.py +4 -13
  33. tme/backends/jax_backend.py +6 -8
  34. tme/backends/matching_backend.py +4 -3
  35. tme/backends/mlx_backend.py +4 -3
  36. tme/backends/npfftw_backend.py +7 -5
  37. tme/backends/pytorch_backend.py +14 -4
  38. tme/cli.py +126 -0
  39. tme/density.py +4 -3
  40. tme/filters/__init__.py +1 -1
  41. tme/filters/_utils.py +4 -3
  42. tme/filters/bandpass.py +6 -4
  43. tme/filters/compose.py +5 -4
  44. tme/filters/ctf.py +426 -214
  45. tme/filters/reconstruction.py +58 -28
  46. tme/filters/wedge.py +139 -61
  47. tme/filters/whitening.py +36 -36
  48. tme/matching_data.py +4 -3
  49. tme/matching_exhaustive.py +17 -16
  50. tme/matching_optimization.py +5 -4
  51. tme/matching_scores.py +4 -3
  52. tme/matching_utils.py +6 -4
  53. tme/memory.py +4 -3
  54. tme/orientations.py +9 -6
  55. tme/parser.py +5 -4
  56. tme/preprocessor.py +4 -3
  57. tme/rotations.py +10 -7
  58. tme/structure.py +4 -3
  59. tests/data/Maps/.DS_Store +0 -0
  60. tests/data/Structures/.DS_Store +0 -0
  61. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
  62. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
  63. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
tests/preprocessing/test_frequency_filters.py CHANGED
@@ -68,7 +68,6 @@ class TestBandPassFilter:
68
68
 
69
69
  assert isinstance(result, dict)
70
70
  assert "data" in result
71
- assert "sampling_rate" in result
72
71
  assert "is_multiplicative_filter" in result
73
72
  assert isinstance(result["data"], type(be.ones((1,))))
74
73
  assert result["is_multiplicative_filter"] is True
@@ -146,7 +145,11 @@ class TestLinearWhiteningFilter:
146
145
  ):
147
146
  data = be.random.random(shape)
148
147
  result = LinearWhiteningFilter()(
149
- data=data, n_bins=n_bins, batch_dimension=batch_dimension, order=order
148
+ shape=shape,
149
+ data=data,
150
+ n_bins=n_bins,
151
+ batch_dimension=batch_dimension,
152
+ order=order,
150
153
  )
151
154
 
152
155
  assert isinstance(result, dict)
@@ -161,7 +164,9 @@ class TestLinearWhiteningFilter:
161
164
  def test_call_method_with_data_rfft(self):
162
165
  shape = (30, 30, 30)
163
166
  data_rfft = be.fft.rfftn(be.random.random(shape))
164
- result = LinearWhiteningFilter()(data_rfft=data_rfft)
167
+ result = LinearWhiteningFilter()(
168
+ shape=shape, data_rfft=data_rfft, return_real_fourier=True
169
+ )
165
170
 
166
171
  assert isinstance(result, dict)
167
172
  assert result.get("data", False) is not False
@@ -172,7 +177,7 @@ class TestLinearWhiteningFilter:
172
177
  @pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
173
178
  def test_filter_mask_range(self, shape: Tuple[int]):
174
179
  data = be.random.random(shape)
175
- result = LinearWhiteningFilter()(data=data)
180
+ result = LinearWhiteningFilter()(shape=shape, data=data)
176
181
 
177
182
  filter_mask = result["data"]
178
183
  assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
tests/test_analyzer.py CHANGED
@@ -1,5 +1,3 @@
1
- from tempfile import mkstemp
2
-
3
1
  import pytest
4
2
  import numpy as np
5
3
 
@@ -24,6 +22,7 @@ PEAK_CALLER_CHILDREN = [
24
22
  PeakCallerScipy,
25
23
  PeakClustering,
26
24
  ]
25
+ np.random.seed(123)
27
26
 
28
27
 
29
28
  class TestPeakCallers:
@@ -54,21 +53,23 @@ class TestPeakCallers:
54
53
  @pytest.mark.parametrize("num_peaks", (1, 100))
55
54
  @pytest.mark.parametrize("minimum_score", (None, 0.5))
56
55
  def test__call__(self, peak_caller, num_peaks, minimum_score):
57
- peak_caller = peak_caller(
58
- shape=self.data.shape,
59
- num_peaks=num_peaks,
60
- min_distance=self.min_distance,
61
- min_score=minimum_score,
62
- )
63
- peak_caller(
56
+ kwargs = {
57
+ "shape": self.data.shape,
58
+ "num_peaks": num_peaks,
59
+ "min_distance": self.min_distance,
60
+ "min_score": minimum_score,
61
+ }
62
+ peak_caller = peak_caller(**kwargs)
63
+ state = peak_caller(
64
+ peak_caller.init_state(),
64
65
  self.data.copy(),
65
66
  rotation_matrix=self.rotation_matrix,
66
67
  )
67
- candidates = tuple(peak_caller)
68
+ state = peak_caller.result(state)
68
69
  if minimum_score is None:
69
- assert len(candidates[0] <= num_peaks)
70
+ assert len(state[0] <= num_peaks)
70
71
  else:
71
- peaks = candidates[0].astype(int)
72
+ peaks = state[0].astype(int)
72
73
  print(self.data[tuple(peaks.T)])
73
74
  assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
74
75
 
@@ -78,139 +79,143 @@ class TestPeakCallers:
78
79
  peak_caller1 = peak_caller(
79
80
  shape=self.data.shape, num_peaks=num_peaks, min_distance=self.min_distance
80
81
  )
81
- peak_caller1(self.data, rotation_matrix=self.rotation_matrix)
82
+ state1 = peak_caller1.init_state()
83
+ state1 = peak_caller1(state1, self.data, rotation_matrix=self.rotation_matrix)
82
84
 
83
85
  peak_caller2 = peak_caller(
84
86
  shape=self.data.shape, num_peaks=num_peaks, min_distance=self.min_distance
85
87
  )
86
- peak_caller2(self.data, rotation_matrix=self.rotation_matrix)
87
-
88
- parameters = [tuple(peak_caller1), tuple(peak_caller2)]
88
+ state2 = peak_caller2.init_state()
89
+ state2 = peak_caller2(state2, self.data, rotation_matrix=self.rotation_matrix)
89
90
 
90
- result = tuple(
91
- peak_caller.merge(
92
- candidates=parameters,
93
- num_peaks=num_peaks,
94
- min_distance=self.min_distance,
95
- )
96
- )
97
- assert [len(res) == 2 for res in result]
98
-
99
-
100
- class TestRecursiveMasking:
101
- def setup_method(self):
102
- self.num_peaks = 100
103
- self.min_distance = 5
104
- self.data = np.random.rand(100, 100, 100)
105
- self.rotation_matrix = np.eye(3)
106
- self.mask = np.random.rand(20, 20, 20)
107
- self.rotation_space = np.zeros_like(self.data)
108
- self.rotation_mapping = {0: (0, 0, 0)}
109
-
110
- @pytest.mark.parametrize("num_peaks", (1, 100))
111
- @pytest.mark.parametrize("compute_rotation", (True, False))
112
- @pytest.mark.parametrize("minimum_score", (None, 0.5))
113
- def test__call__(self, num_peaks, compute_rotation, minimum_score):
114
- peak_caller = PeakCallerRecursiveMasking(
115
- shape=self.data.shape,
91
+ states = [peak_caller1.result(state1), peak_caller2.result(state2)]
92
+ result = peak_caller.merge(
93
+ results=states,
116
94
  num_peaks=num_peaks,
117
95
  min_distance=self.min_distance,
118
- min_score=minimum_score,
119
- )
120
- rotation_space, rotation_mapping = None, None
121
- if compute_rotation:
122
- rotation_space = self.rotation_space
123
- rotation_mapping = self.rotation_mapping
124
-
125
- peak_caller(
126
- self.data.copy(),
127
- rotation_matrix=self.rotation_matrix,
128
- mask=self.mask,
129
- rotation_space=rotation_space,
130
- rotation_mapping=rotation_mapping,
131
96
  )
97
+ assert [len(res) == 2 for res in result]
132
98
 
133
- candidates = tuple(peak_caller)
134
- if minimum_score is None:
135
- assert len(candidates[0] <= num_peaks)
136
- else:
137
- peaks = candidates[0].astype(int)
138
- assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
139
-
140
-
141
- class TestMaxScoreOverRotations:
142
- def setup_method(self):
143
- self.num_peaks = 100
144
- self.min_distance = 5
145
- self.data = np.random.rand(100, 100, 100)
146
- self.rotation_matrix = np.eye(3)
147
-
148
- def test_initialization(self):
149
- _ = MaxScoreOverRotations(
150
- shape=self.data.shape,
151
- translation_offset=np.zeros(self.data.ndim, dtype=int),
152
- )
153
-
154
- @pytest.mark.parametrize("use_memmap", [False, True])
155
- def test__iter__(self, use_memmap: bool):
156
- score_analyzer = MaxScoreOverRotations(
157
- shape=self.data.shape,
158
- use_memmap=use_memmap,
159
- )
160
- score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
161
- res = tuple(score_analyzer)
162
- assert np.allclose(res[0].shape, self.data.shape)
163
- assert res[0].dtype == be._float_dtype
164
- assert res[1].size == self.data.ndim
165
- assert np.allclose(res[2].shape, self.data.shape)
166
- assert len(res) == 4
167
-
168
- @pytest.mark.parametrize("use_memmap", [False, True])
169
- @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
170
- def test__call__(self, use_memmap: bool, score_threshold: float):
171
- score_analyzer = MaxScoreOverRotations(
172
- shape=self.data.shape,
173
- score_threshold=score_threshold,
174
- translation_offset=np.zeros(self.data.ndim, dtype=int),
175
- use_memmap=use_memmap,
176
- )
177
- score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
178
-
179
- data2 = self.data * 2
180
- score_analyzer(data2, rotation_matrix=self.rotation_matrix)
181
- scores, translation_offset, rotations, mapping = tuple(score_analyzer)
182
- assert np.all(scores >= score_threshold)
183
- max_scores = np.maximum(self.data, data2)
184
- max_scores = np.maximum(max_scores, score_threshold)
185
- assert np.allclose(scores, max_scores)
186
-
187
- @pytest.mark.parametrize("use_memmap", [False, True])
188
- @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
189
- def test_merge(self, use_memmap: bool, score_threshold: float):
190
- score_analyzer = MaxScoreOverRotations(
191
- shape=self.data.shape,
192
- score_threshold=score_threshold,
193
- translation_offset=np.zeros(self.data.ndim, dtype=int),
194
- use_memmap=use_memmap,
195
- )
196
- score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
197
-
198
- data2 = self.data * 2
199
- score_analyzer2 = MaxScoreOverRotations(
200
- shape=self.data.shape,
201
- score_threshold=score_threshold,
202
- translation_offset=np.zeros(self.data.ndim, dtype=int),
203
- use_memmap=use_memmap,
204
- )
205
- score_analyzer2(data2, rotation_matrix=self.rotation_matrix)
206
-
207
- parameters = [tuple(score_analyzer), tuple(score_analyzer2)]
208
99
 
209
- ret = MaxScoreOverRotations.merge(
210
- parameters, use_memmap=use_memmap, score_threshold=score_threshold
211
- )
212
- scores, translation, rotations, mapping = ret
213
- assert np.all(scores >= score_threshold)
214
- max_scores = np.maximum(self.data, data2)
215
- max_scores = np.maximum(max_scores, score_threshold)
216
- assert np.allclose(scores, max_scores)
100
+ # class TestRecursiveMasking:
101
+ # def setup_method(self):
102
+ # self.num_peaks = 100
103
+ # self.min_distance = 5
104
+ # self.data = np.random.rand(100, 100, 100)
105
+ # self.rotation_matrix = np.eye(3)
106
+ # self.mask = np.random.rand(20, 20, 20)
107
+ # self.rotation_space = np.zeros_like(self.data)
108
+ # self.rotation_mapping = {0: (0, 0, 0)}
109
+
110
+ # @pytest.mark.parametrize("num_peaks", (1, 100))
111
+ # @pytest.mark.parametrize("compute_rotation", (True, False))
112
+ # @pytest.mark.parametrize("minimum_score", (None, 0.5))
113
+ # def test__call__(self, num_peaks, compute_rotation, minimum_score):
114
+ # peak_caller = PeakCallerRecursiveMasking(
115
+ # shape=self.data.shape,
116
+ # num_peaks=num_peaks,
117
+ # min_distance=self.min_distance,
118
+ # min_score=minimum_score,
119
+ # )
120
+ # rotation_space, rotation_mapping = None, None
121
+ # if compute_rotation:
122
+ # rotation_space = self.rotation_space
123
+ # rotation_mapping = self.rotation_mapping
124
+
125
+ # state = peak_caller.init_state()
126
+ # state = peak_caller(
127
+ # state,
128
+ # self.data.copy(),
129
+ # rotation_matrix=self.rotation_matrix,
130
+ # mask=self.mask,
131
+ # rotation_space=rotation_space,
132
+ # rotation_mapping=rotation_mapping,
133
+ # )
134
+
135
+ # if minimum_score is None:
136
+ # assert len(state[0] <= num_peaks)
137
+ # else:
138
+ # peaks = state[0].astype(int)
139
+ # assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
140
+
141
+
142
+ # class TestMaxScoreOverRotations:
143
+ # def setup_method(self):
144
+ # self.num_peaks = 100
145
+ # self.min_distance = 5
146
+ # self.data = np.random.rand(100, 100, 100)
147
+ # self.rotation_matrix = np.eye(3)
148
+
149
+ # def test_initialization(self):
150
+ # _ = MaxScoreOverRotations(
151
+ # shape=self.data.shape,
152
+ # translation_offset=np.zeros(self.data.ndim, dtype=int),
153
+ # )
154
+
155
+ # @pytest.mark.parametrize("use_memmap", [False, True])
156
+ # def test__iter__(self, use_memmap: bool):
157
+ # score_analyzer = MaxScoreOverRotations(
158
+ # shape=self.data.shape,
159
+ # use_memmap=use_memmap,
160
+ # )
161
+ # state = score_analyzer.init_state()
162
+ # state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
163
+ # res = score_analyzer.result(state)
164
+ # assert np.allclose(res[0].shape, self.data.shape)
165
+ # assert res[0].dtype == be._float_dtype
166
+ # assert res[1].size == self.data.ndim
167
+ # assert np.allclose(res[2].shape, self.data.shape)
168
+ # assert len(res) == 4
169
+
170
+ # @pytest.mark.parametrize("use_memmap", [False, True])
171
+ # @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
172
+ # def test__call__(self, use_memmap: bool, score_threshold: float):
173
+ # score_analyzer = MaxScoreOverRotations(
174
+ # shape=self.data.shape,
175
+ # score_threshold=score_threshold,
176
+ # translation_offset=np.zeros(self.data.ndim, dtype=int),
177
+ # use_memmap=use_memmap,
178
+ # )
179
+ # state = score_analyzer.init_state()
180
+ # state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
181
+
182
+ # data2 = self.data * 2
183
+ # score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
184
+ # scores, translation_offset, rotations, mapping = score_analyzer.result(state)
185
+
186
+ # assert np.all(scores >= score_threshold)
187
+ # max_scores = np.maximum(self.data, data2)
188
+ # max_scores = np.maximum(max_scores, score_threshold)
189
+ # assert np.allclose(scores, max_scores)
190
+
191
+ # @pytest.mark.parametrize("use_memmap", [False, True])
192
+ # @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
193
+ # def test_merge(self, use_memmap: bool, score_threshold: float):
194
+ # score_analyzer = MaxScoreOverRotations(
195
+ # shape=self.data.shape,
196
+ # score_threshold=score_threshold,
197
+ # translation_offset=np.zeros(self.data.ndim, dtype=int),
198
+ # use_memmap=use_memmap,
199
+ # )
200
+ # state1 = score_analyzer.init_state()
201
+ # state1, score_analyzer(state1, self.data, rotation_matrix=self.rotation_matrix)
202
+
203
+ # data2 = self.data * 2
204
+ # score_analyzer2 = MaxScoreOverRotations(
205
+ # shape=self.data.shape,
206
+ # score_threshold=score_threshold,
207
+ # translation_offset=np.zeros(self.data.ndim, dtype=int),
208
+ # use_memmap=use_memmap,
209
+ # )
210
+ # state2 = score_analyzer2.init_state()
211
+ # state2 = score_analyzer2(state2, data2, rotation_matrix=self.rotation_matrix)
212
+ # states = [score_analyzer.result(state1), score_analyzer2.result(state2)]
213
+
214
+ # ret = MaxScoreOverRotations.merge(
215
+ # results=states, use_memmap=use_memmap, score_threshold=score_threshold
216
+ # )
217
+ # scores, translation, rotations, mapping = ret
218
+ # assert np.all(scores >= score_threshold)
219
+ # max_scores = np.maximum(self.data, data2)
220
+ # max_scores = np.maximum(max_scores, score_threshold)
221
+ # assert np.allclose(scores, max_scores)
tests/test_matching_cli.py CHANGED
@@ -6,26 +6,10 @@ from os import remove, makedirs
6
6
 
7
7
  import pytest
8
8
  import numpy as np
9
- from tme import Density
9
+ from tme import Density, Orientations
10
10
  from tme.backends import backend as be
11
11
 
12
- BACKEND_CLASSES = ["NumpyFFTWBackend", "PytorchBackend", "CupyBackend", "MLXBackend"]
13
- BACKENDS_TO_TEST = []
14
-
15
- test_gpu = (False,)
16
- for backend_class in BACKEND_CLASSES:
17
- try:
18
- BackendClass = getattr(
19
- __import__("tme.backends", fromlist=[backend_class]), backend_class
20
- )
21
- BACKENDS_TO_TEST.append(BackendClass(device="cpu"))
22
- if backend_class == "CupyBackend":
23
- if BACKENDS_TO_TEST[-1].device_count() >= 1:
24
- test_gpu = (False, True)
25
- except ImportError:
26
- print(f"Couldn't import {backend_class}. Skipping...")
27
-
28
-
12
+ np.random.seed(42)
29
13
  available_backends = (x for x in be.available_backends() if x != "mlx")
30
14
 
31
15
 
@@ -48,7 +32,6 @@ def argdict_to_command(input_args, executable: str):
48
32
  class TestMatchTemplate:
49
33
  @classmethod
50
34
  def setup_class(cls):
51
- np.random.seed(42)
52
35
  target = np.random.rand(20, 20, 20)
53
36
  template = np.random.rand(5, 5, 5)
54
37
 
@@ -65,6 +48,19 @@ class TestMatchTemplate:
65
48
  cls.template_mask_path = tempfile.NamedTemporaryFile(
66
49
  delete=False, suffix=".mrc"
67
50
  ).name
51
+ cls.tempdir = tempfile.TemporaryDirectory().name
52
+ makedirs(cls.tempdir, exist_ok=True)
53
+
54
+ orientations = Orientations(
55
+ translations=((10, 10, 10), (12, 10, 15)),
56
+ rotations=((0, 0, 0), (45, 12, 90)),
57
+ scores=(0, 0),
58
+ details=(-1, -1),
59
+ )
60
+ cls.orientations_path = tempfile.NamedTemporaryFile(
61
+ delete=False, suffix=".star"
62
+ ).name
63
+ orientations.to_file(cls.orientations_path)
68
64
 
69
65
  Density(target, sampling_rate=5).to_file(cls.target_path)
70
66
  Density(template, sampling_rate=5).to_file(cls.template_path)
@@ -76,6 +72,8 @@ class TestMatchTemplate:
76
72
  cls.try_delete(cls.template_path)
77
73
  cls.try_delete(cls.target_mask_path)
78
74
  cls.try_delete(cls.template_mask_path)
75
+ cls.try_delete(cls.orientations_path)
76
+ cls.try_delete(cls.tempdir)
79
77
 
80
78
  @staticmethod
81
79
  def try_delete(file_path: str):
@@ -88,8 +86,8 @@ class TestMatchTemplate:
88
86
  except Exception:
89
87
  pass
90
88
 
91
- @staticmethod
92
89
  def run_matching(
90
+ self,
93
91
  use_template_mask: bool,
94
92
  test_filter: bool,
95
93
  call_peaks: bool,
@@ -99,6 +97,7 @@ class TestMatchTemplate:
99
97
  target_mask_path: str,
100
98
  use_target_mask: bool = False,
101
99
  backend: str = "numpyfftw",
100
+ test_rejection_sampling: bool = False,
102
101
  ):
103
102
  output_path = tempfile.NamedTemporaryFile(delete=False, suffix="pickle").name
104
103
 
@@ -109,7 +108,6 @@ class TestMatchTemplate:
109
108
  "-a": 60,
110
109
  "-o": output_path,
111
110
  "--pad_edges": False,
112
- "--pad_fourier": False,
113
111
  "--backend": backend,
114
112
  }
115
113
 
@@ -119,14 +117,14 @@ class TestMatchTemplate:
119
117
  if use_target_mask:
120
118
  argdict["--target_mask"] = target_mask_path
121
119
 
122
- if backend in ("cupy", "pytorch") and True in test_gpu:
123
- argdict["--use_gpu"] = True
120
+ if test_rejection_sampling:
121
+ argdict["--orientations"] = self.orientations_path
124
122
 
125
123
  if test_filter:
126
124
  argdict["--lowpass"] = 30
127
125
  argdict["--defocus"] = 3000
128
- argdict["--tilt_angles"] = "40,40:10"
129
- argdict["--wedge_axes"] = "0,2"
126
+ argdict["--tilt_angles"] = "40,40"
127
+ argdict["--wedge_axes"] = "2,0"
130
128
  argdict["--whiten"] = True
131
129
 
132
130
  if call_peaks:
@@ -142,13 +140,18 @@ class TestMatchTemplate:
142
140
  @pytest.mark.parametrize("call_peaks", (False, True))
143
141
  @pytest.mark.parametrize("use_template_mask", (False, True))
144
142
  @pytest.mark.parametrize("test_filter", (False, True))
143
+ @pytest.mark.parametrize("test_rejection_sampling", (False, True))
145
144
  def test_match_template(
146
145
  self,
147
146
  backend: bool,
148
147
  call_peaks: bool,
149
148
  use_template_mask: bool,
150
149
  test_filter: bool,
150
+ test_rejection_sampling: bool,
151
151
  ):
152
+ if backend == "jax" and (call_peaks or test_rejection_sampling):
153
+ return None
154
+
152
155
  self.run_matching(
153
156
  use_template_mask=use_template_mask,
154
157
  use_target_mask=True,
@@ -159,6 +162,7 @@ class TestMatchTemplate:
159
162
  target_path=self.target_path,
160
163
  template_mask_path=self.template_mask_path,
161
164
  target_mask_path=self.target_mask_path,
165
+ test_rejection_sampling=test_rejection_sampling,
162
166
  )
163
167
 
164
168
 
@@ -175,20 +179,20 @@ class TestPostprocessing(TestMatchTemplate):
175
179
  "target_path": cls.target_path,
176
180
  "template_mask_path": cls.template_mask_path,
177
181
  "target_mask_path": cls.target_mask_path,
182
+ "test_rejection_sampling": False,
178
183
  }
179
184
 
180
185
  cls.score_pickle = cls.run_matching(
186
+ cls,
181
187
  call_peaks=False,
182
188
  **matching_kwargs,
183
189
  )
184
- cls.peak_pickle = cls.run_matching(call_peaks=True, **matching_kwargs)
185
- cls.tempdir = tempfile.TemporaryDirectory().name
190
+ cls.peak_pickle = cls.run_matching(cls, call_peaks=True, **matching_kwargs)
186
191
 
187
192
  @classmethod
188
193
  def teardown_class(cls):
189
194
  cls.try_delete(cls.score_pickle)
190
195
  cls.try_delete(cls.peak_pickle)
191
- cls.try_delete(cls.tempdir)
192
196
 
193
197
  @pytest.mark.parametrize("distance_cutoff_strategy", (0, 1, 2, 3))
194
198
  @pytest.mark.parametrize("score_cutoff", (None, (1,), (0, 1), (None, 1), (0, None)))
@@ -252,6 +256,7 @@ class TestPostprocessing(TestMatchTemplate):
252
256
  }
253
257
  cmd = argdict_to_command(argdict, executable="postprocess.py")
254
258
  ret = subprocess.run(cmd, capture_output=True, shell=True)
259
+ print(ret)
255
260
 
256
261
  match output_format:
257
262
  case "orientations":
@@ -264,7 +269,8 @@ class TestPostprocessing(TestMatchTemplate):
264
269
  assert exists(f"{self.tempdir}/temp.star")
265
270
  case "relion5":
266
271
  assert exists(f"{self.tempdir}/temp.star")
267
-
272
+ case "pickle":
273
+ assert exists(f"{self.tempdir}/temp.pickle")
268
274
  assert ret.returncode == 0
269
275
 
270
276
  def test_postprocess_score_local_optimization(self):
@@ -281,3 +287,53 @@ class TestPostprocessing(TestMatchTemplate):
281
287
  cmd = argdict_to_command(argdict, executable="postprocess.py")
282
288
  ret = subprocess.run(cmd, capture_output=True, shell=True)
283
289
  assert ret.returncode == 0
290
+
291
+
292
+ class TestEstimateMemoryUsage(TestMatchTemplate):
293
+ @classmethod
294
+ def setup_class(cls):
295
+ super().setup_class()
296
+
297
+ @pytest.mark.parametrize("ncores", (1, 4, 8))
298
+ @pytest.mark.parametrize("pad_edges", (False, True))
299
+ def test_estimation(self, ncores, pad_edges):
300
+
301
+ argdict = {
302
+ "-m": self.target_path,
303
+ "-i": self.template_path,
304
+ "--ncores": ncores,
305
+ "--pad_edges": pad_edges,
306
+ "--score": "FLCSphericalMask",
307
+ }
308
+
309
+ cmd = argdict_to_command(argdict, executable="estimate_memory_usage.py")
310
+ ret = subprocess.run(cmd, capture_output=True, shell=True)
311
+ assert ret.returncode == 0
312
+
313
+
314
+ class TestPreprocess(TestMatchTemplate):
315
+ @classmethod
316
+ def setup_class(cls):
317
+ super().setup_class()
318
+
319
+ @pytest.mark.parametrize("backend", available_backends)
320
+ @pytest.mark.parametrize("align_axis", (False, True))
321
+ @pytest.mark.parametrize("invert_contrast", (False, True))
322
+ def test_estimation(self, backend, align_axis, invert_contrast):
323
+
324
+ argdict = {
325
+ "-m": self.target_path,
326
+ "--backend": backend,
327
+ "--lowpass": 40,
328
+ "--sampling_rate": 5,
329
+ "-o": f"{self.tempdir}/out.mrc",
330
+ }
331
+ if align_axis:
332
+ argdict["--align_axis"] = 2
333
+
334
+ if invert_contrast:
335
+ argdict["--invert_contrast"] = True
336
+
337
+ cmd = argdict_to_command(argdict, executable="preprocess.py")
338
+ ret = subprocess.run(cmd, capture_output=True, shell=True)
339
+ assert ret.returncode == 0
tests/test_matching_exhaustive.py CHANGED
@@ -7,7 +7,6 @@ from tme.matching_data import MatchingData
7
7
  from tme.memory import MATCHING_MEMORY_REGISTRY
8
8
  from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
9
9
  from tme.matching_exhaustive import (
10
- scan,
11
10
  scan_subsets,
12
11
  MATCHING_EXHAUSTIVE_REGISTER,
13
12
  register_matching_exhaustive,
@@ -36,7 +35,7 @@ class TestMatchExhaustive:
36
35
  self.coordinates_weights = None
37
36
  self.rotations = None
38
37
 
39
- @pytest.mark.parametrize("evaluate_peak", (False, True))
38
+ @pytest.mark.parametrize("evaluate_peak", (True,))
40
39
  @pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
41
40
  @pytest.mark.parametrize("job_schedule", ((2, 1),))
42
41
  @pytest.mark.parametrize("pad_edge", (False, True))
tests/test_matching_optimization.py CHANGED
@@ -24,6 +24,7 @@ coordinate_to_coordinate = [
24
24
  for k, v in MATCHING_OPTIMIZATION_REGISTER.items()
25
25
  if issubclass(v, _MatchCoordinatesToCoordinates)
26
26
  ]
27
+ np.random.seed(42)
27
28
 
28
29
 
29
30
  class TestMatchDensityToDensity:
@@ -52,9 +53,7 @@ class TestMatchDensityToDensity:
52
53
 
53
54
  @pytest.mark.parametrize("method", density_to_density)
54
55
  def test_call(self, method):
55
- instance = self.test_initialization(method=method, notest=True)
56
- score = instance()
57
- assert isinstance(score, float)
56
+ self.test_initialization(method=method, notest=True)()
58
57
 
59
58
 
60
59
  class TestMatchDensityToCoordinates:
@@ -97,9 +96,7 @@ class TestMatchDensityToCoordinates:
97
96
 
98
97
  @pytest.mark.parametrize("method", coordinate_to_density)
99
98
  def test_call(self, method):
100
- instance = self.test_initialization(method=method, notest=True)
101
- score = instance()
102
- assert isinstance(score, float)
99
+ self.test_initialization(method=method, notest=True)()
103
100
 
104
101
 
105
102
  class TestMatchCoordinateToCoordinates:
@@ -135,9 +132,7 @@ class TestMatchCoordinateToCoordinates:
135
132
 
136
133
  @pytest.mark.parametrize("method", coordinate_to_coordinate)
137
134
  def test_call(self, method):
138
- instance = self.test_initialization(method=method, notest=True)
139
- score = instance()
140
- assert isinstance(score, float)
135
+ self.test_initialization(method=method, notest=True)()
141
136
 
142
137
 
143
138
  class TestOptimizeMatch:
tests/test_orientations.py CHANGED
@@ -50,7 +50,6 @@ class TestDensity:
50
50
  assert np.issubdtype(orientations.translations.dtype, np.floating)
51
51
  assert np.issubdtype(orientations.rotations.dtype, np.floating)
52
52
  assert np.issubdtype(orientations.scores.dtype, np.floating)
53
- assert np.issubdtype(orientations.details.dtype, np.floating)
54
53
 
55
54
  def test_initialization_error(self):
56
55
  with pytest.raises(ValueError):
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.2.9"
1
+ __version__ = "0.3.b0"
tme/analyzer/__init__.py CHANGED
@@ -1,2 +1,4 @@
1
1
  from .peaks import *
2
2
  from .aggregation import *
3
+ from .proxy import *
4
+ from .base import *