pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -53
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +396 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -201
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +158 -28
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.dist-info/RECORD +0 -119
  70. pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
@@ -4,13 +4,15 @@ from typing import Tuple
4
4
 
5
5
  from tme.backends import backend as be
6
6
  from tme.filters._utils import compute_fourier_shape
7
- from tme.filters import BandPassFilter, LinearWhiteningFilter
7
+ from tme.filters import BandPassReconstructed, LinearWhiteningFilter
8
+ from tme.filters.bandpass import gaussian_bandpass, discrete_bandpass
9
+ from tme.filters._utils import fftfreqn
8
10
 
9
11
 
10
12
  class TestBandPassFilter:
11
13
  @pytest.fixture
12
14
  def band_pass_filter(self):
13
- return BandPassFilter()
15
+ return BandPassReconstructed()
14
16
 
15
17
  @pytest.mark.parametrize(
16
18
  "shape, lowpass, highpass, sampling_rate",
@@ -24,9 +26,13 @@ class TestBandPassFilter:
24
26
  def test_discrete_bandpass(
25
27
  self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
26
28
  ):
27
- result = BandPassFilter.discrete_bandpass(
28
- shape, lowpass, highpass, sampling_rate
29
+ grid = fftfreqn(
30
+ shape=shape,
31
+ sampling_rate=0.5,
32
+ shape_is_real_fourier=False,
33
+ compute_euclidean_norm=True,
29
34
  )
35
+ result = discrete_bandpass(grid, lowpass, highpass, sampling_rate)
30
36
  assert isinstance(result, type(be.ones((1,))))
31
37
  assert result.shape == shape
32
38
  assert np.all((result >= 0) & (result <= 1))
@@ -43,9 +49,13 @@ class TestBandPassFilter:
43
49
  def test_gaussian_bandpass(
44
50
  self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
45
51
  ):
46
- result = BandPassFilter.gaussian_bandpass(
47
- shape, lowpass, highpass, sampling_rate
52
+ grid = fftfreqn(
53
+ shape=shape,
54
+ sampling_rate=0.5,
55
+ shape_is_real_fourier=False,
56
+ compute_euclidean_norm=True,
48
57
  )
58
+ result = gaussian_bandpass(grid, lowpass, highpass, sampling_rate)
49
59
  assert isinstance(result, type(be.ones((1,))))
50
60
  assert result.shape == shape
51
61
  assert np.all((result >= 0) & (result <= 1))
@@ -55,7 +65,7 @@ class TestBandPassFilter:
55
65
  @pytest.mark.parametrize("shape_is_real_fourier", [True, False])
56
66
  def test_call_method(
57
67
  self,
58
- band_pass_filter: BandPassFilter,
68
+ band_pass_filter: BandPassReconstructed,
59
69
  use_gaussian: bool,
60
70
  return_real_fourier: bool,
61
71
  shape_is_real_fourier: bool,
@@ -68,22 +78,20 @@ class TestBandPassFilter:
68
78
 
69
79
  assert isinstance(result, dict)
70
80
  assert "data" in result
71
- assert "sampling_rate" in result
72
81
  assert "is_multiplicative_filter" in result
73
82
  assert isinstance(result["data"], type(be.ones((1,))))
74
83
  assert result["is_multiplicative_filter"] is True
75
84
 
76
- def test_default_values(self, band_pass_filter: BandPassFilter):
85
+ def test_default_values(self, band_pass_filter: BandPassReconstructed):
77
86
  assert band_pass_filter.lowpass is None
78
87
  assert band_pass_filter.highpass is None
79
88
  assert band_pass_filter.sampling_rate == 1
80
89
  assert band_pass_filter.use_gaussian is True
81
90
  assert band_pass_filter.return_real_fourier is False
82
- assert band_pass_filter.shape_is_real_fourier is False
83
91
 
84
92
  @pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
85
93
  def test_return_real_fourier(self, shape: Tuple[int]):
86
- bpf = BandPassFilter(return_real_fourier=True)
94
+ bpf = BandPassReconstructed(return_real_fourier=True)
87
95
  result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
88
96
  expected_shape = tuple(compute_fourier_shape(shape, False))
89
97
  assert result["data"].shape == expected_shape
@@ -146,7 +154,11 @@ class TestLinearWhiteningFilter:
146
154
  ):
147
155
  data = be.random.random(shape)
148
156
  result = LinearWhiteningFilter()(
149
- data=data, n_bins=n_bins, batch_dimension=batch_dimension, order=order
157
+ shape=shape,
158
+ data=data,
159
+ n_bins=n_bins,
160
+ batch_dimension=batch_dimension,
161
+ order=order,
150
162
  )
151
163
 
152
164
  assert isinstance(result, dict)
@@ -161,7 +173,9 @@ class TestLinearWhiteningFilter:
161
173
  def test_call_method_with_data_rfft(self):
162
174
  shape = (30, 30, 30)
163
175
  data_rfft = be.fft.rfftn(be.random.random(shape))
164
- result = LinearWhiteningFilter()(data_rfft=data_rfft)
176
+ result = LinearWhiteningFilter()(
177
+ shape=shape, data_rfft=data_rfft, return_real_fourier=True
178
+ )
165
179
 
166
180
  assert isinstance(result, dict)
167
181
  assert result.get("data", False) is not False
@@ -172,7 +186,7 @@ class TestLinearWhiteningFilter:
172
186
  @pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
173
187
  def test_filter_mask_range(self, shape: Tuple[int]):
174
188
  data = be.random.random(shape)
175
- result = LinearWhiteningFilter()(data=data)
189
+ result = LinearWhiteningFilter()(shape=shape, data=data)
176
190
 
177
191
  filter_mask = result["data"]
178
192
  assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
tests/test_analyzer.py CHANGED
@@ -1,5 +1,3 @@
1
- from tempfile import mkstemp
2
-
3
1
  import pytest
4
2
  import numpy as np
5
3
 
@@ -24,6 +22,7 @@ PEAK_CALLER_CHILDREN = [
24
22
  PeakCallerScipy,
25
23
  PeakClustering,
26
24
  ]
25
+ np.random.seed(123)
27
26
 
28
27
 
29
28
  class TestPeakCallers:
@@ -54,21 +53,23 @@ class TestPeakCallers:
54
53
  @pytest.mark.parametrize("num_peaks", (1, 100))
55
54
  @pytest.mark.parametrize("minimum_score", (None, 0.5))
56
55
  def test__call__(self, peak_caller, num_peaks, minimum_score):
57
- peak_caller = peak_caller(
58
- shape=self.data.shape,
59
- num_peaks=num_peaks,
60
- min_distance=self.min_distance,
61
- min_score=minimum_score,
62
- )
63
- peak_caller(
56
+ kwargs = {
57
+ "shape": self.data.shape,
58
+ "num_peaks": num_peaks,
59
+ "min_distance": self.min_distance,
60
+ "min_score": minimum_score,
61
+ }
62
+ peak_caller = peak_caller(**kwargs)
63
+ state = peak_caller(
64
+ peak_caller.init_state(),
64
65
  self.data.copy(),
65
66
  rotation_matrix=self.rotation_matrix,
66
67
  )
67
- candidates = tuple(peak_caller)
68
+ state = peak_caller.result(state)
68
69
  if minimum_score is None:
69
- assert len(candidates[0] <= num_peaks)
70
+ assert len(state[0] <= num_peaks)
70
71
  else:
71
- peaks = candidates[0].astype(int)
72
+ peaks = state[0].astype(int)
72
73
  print(self.data[tuple(peaks.T)])
73
74
  assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
74
75
 
@@ -78,21 +79,20 @@ class TestPeakCallers:
78
79
  peak_caller1 = peak_caller(
79
80
  shape=self.data.shape, num_peaks=num_peaks, min_distance=self.min_distance
80
81
  )
81
- peak_caller1(self.data, rotation_matrix=self.rotation_matrix)
82
+ state1 = peak_caller1.init_state()
83
+ state1 = peak_caller1(state1, self.data, rotation_matrix=self.rotation_matrix)
82
84
 
83
85
  peak_caller2 = peak_caller(
84
86
  shape=self.data.shape, num_peaks=num_peaks, min_distance=self.min_distance
85
87
  )
86
- peak_caller2(self.data, rotation_matrix=self.rotation_matrix)
88
+ state2 = peak_caller2.init_state()
89
+ state2 = peak_caller2(state2, self.data, rotation_matrix=self.rotation_matrix)
87
90
 
88
- parameters = [tuple(peak_caller1), tuple(peak_caller2)]
89
-
90
- result = tuple(
91
- peak_caller.merge(
92
- candidates=parameters,
93
- num_peaks=num_peaks,
94
- min_distance=self.min_distance,
95
- )
91
+ states = [peak_caller1.result(state1), peak_caller2.result(state2)]
92
+ result = peak_caller.merge(
93
+ results=states,
94
+ num_peaks=num_peaks,
95
+ min_distance=self.min_distance,
96
96
  )
97
97
  assert [len(res) == 2 for res in result]
98
98
 
@@ -122,7 +122,9 @@ class TestRecursiveMasking:
122
122
  rotation_space = self.rotation_space
123
123
  rotation_mapping = self.rotation_mapping
124
124
 
125
- peak_caller(
125
+ state = peak_caller.init_state()
126
+ state = peak_caller(
127
+ state,
126
128
  self.data.copy(),
127
129
  rotation_matrix=self.rotation_matrix,
128
130
  mask=self.mask,
@@ -130,11 +132,10 @@ class TestRecursiveMasking:
130
132
  rotation_mapping=rotation_mapping,
131
133
  )
132
134
 
133
- candidates = tuple(peak_caller)
134
135
  if minimum_score is None:
135
- assert len(candidates[0] <= num_peaks)
136
+ assert len(state[0] <= num_peaks)
136
137
  else:
137
- peaks = candidates[0].astype(int)
138
+ peaks = state[0].astype(int)
138
139
  assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
139
140
 
140
141
 
@@ -157,8 +158,9 @@ class TestMaxScoreOverRotations:
157
158
  shape=self.data.shape,
158
159
  use_memmap=use_memmap,
159
160
  )
160
- score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
161
- res = tuple(score_analyzer)
161
+ state = score_analyzer.init_state()
162
+ state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
163
+ res = score_analyzer.result(state)
162
164
  assert np.allclose(res[0].shape, self.data.shape)
163
165
  assert res[0].dtype == be._float_dtype
164
166
  assert res[1].size == self.data.ndim
@@ -174,11 +176,13 @@ class TestMaxScoreOverRotations:
174
176
  translation_offset=np.zeros(self.data.ndim, dtype=int),
175
177
  use_memmap=use_memmap,
176
178
  )
177
- score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
179
+ state = score_analyzer.init_state()
180
+ state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
178
181
 
179
182
  data2 = self.data * 2
180
- score_analyzer(data2, rotation_matrix=self.rotation_matrix)
181
- scores, translation_offset, rotations, mapping = tuple(score_analyzer)
183
+ score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
184
+ scores, translation_offset, rotations, mapping = score_analyzer.result(state)
185
+
182
186
  assert np.all(scores >= score_threshold)
183
187
  max_scores = np.maximum(self.data, data2)
184
188
  max_scores = np.maximum(max_scores, score_threshold)
@@ -193,7 +197,8 @@ class TestMaxScoreOverRotations:
193
197
  translation_offset=np.zeros(self.data.ndim, dtype=int),
194
198
  use_memmap=use_memmap,
195
199
  )
196
- score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
200
+ state1 = score_analyzer.init_state()
201
+ state1, score_analyzer(state1, self.data, rotation_matrix=self.rotation_matrix)
197
202
 
198
203
  data2 = self.data * 2
199
204
  score_analyzer2 = MaxScoreOverRotations(
@@ -202,12 +207,12 @@ class TestMaxScoreOverRotations:
202
207
  translation_offset=np.zeros(self.data.ndim, dtype=int),
203
208
  use_memmap=use_memmap,
204
209
  )
205
- score_analyzer2(data2, rotation_matrix=self.rotation_matrix)
206
-
207
- parameters = [tuple(score_analyzer), tuple(score_analyzer2)]
210
+ state2 = score_analyzer2.init_state()
211
+ state2 = score_analyzer2(state2, data2, rotation_matrix=self.rotation_matrix)
212
+ states = [score_analyzer.result(state1), score_analyzer2.result(state2)]
208
213
 
209
214
  ret = MaxScoreOverRotations.merge(
210
- parameters, use_memmap=use_memmap, score_threshold=score_threshold
215
+ results=states, use_memmap=use_memmap, score_threshold=score_threshold
211
216
  )
212
217
  scores, translation, rotations, mapping = ret
213
218
  assert np.all(scores >= score_threshold)
tests/test_backends.py CHANGED
@@ -423,6 +423,7 @@ class TestBackends:
423
423
  rotation_matrix=rotation_matrix,
424
424
  out=out,
425
425
  out_mask=out_mask,
426
+ batched=True,
426
427
  )
427
428
 
428
429
  arr_b = backend.to_backend_array(arr_b)
@@ -6,26 +6,10 @@ from os import remove, makedirs
6
6
 
7
7
  import pytest
8
8
  import numpy as np
9
- from tme import Density
9
+ from tme import Density, Orientations
10
10
  from tme.backends import backend as be
11
11
 
12
- BACKEND_CLASSES = ["NumpyFFTWBackend", "PytorchBackend", "CupyBackend", "MLXBackend"]
13
- BACKENDS_TO_TEST = []
14
-
15
- test_gpu = (False,)
16
- for backend_class in BACKEND_CLASSES:
17
- try:
18
- BackendClass = getattr(
19
- __import__("tme.backends", fromlist=[backend_class]), backend_class
20
- )
21
- BACKENDS_TO_TEST.append(BackendClass(device="cpu"))
22
- if backend_class == "CupyBackend":
23
- if BACKENDS_TO_TEST[-1].device_count() >= 1:
24
- test_gpu = (False, True)
25
- except ImportError:
26
- print(f"Couldn't import {backend_class}. Skipping...")
27
-
28
-
12
+ np.random.seed(42)
29
13
  available_backends = (x for x in be.available_backends() if x != "mlx")
30
14
 
31
15
 
@@ -48,7 +32,6 @@ def argdict_to_command(input_args, executable: str):
48
32
  class TestMatchTemplate:
49
33
  @classmethod
50
34
  def setup_class(cls):
51
- np.random.seed(42)
52
35
  target = np.random.rand(20, 20, 20)
53
36
  template = np.random.rand(5, 5, 5)
54
37
 
@@ -65,6 +48,19 @@ class TestMatchTemplate:
65
48
  cls.template_mask_path = tempfile.NamedTemporaryFile(
66
49
  delete=False, suffix=".mrc"
67
50
  ).name
51
+ cls.tempdir = tempfile.TemporaryDirectory().name
52
+ makedirs(cls.tempdir, exist_ok=True)
53
+
54
+ orientations = Orientations(
55
+ translations=((10, 10, 10), (12, 10, 15)),
56
+ rotations=((0, 0, 0), (45, 12, 90)),
57
+ scores=(0, 0),
58
+ details=(-1, -1),
59
+ )
60
+ cls.orientations_path = tempfile.NamedTemporaryFile(
61
+ delete=False, suffix=".star"
62
+ ).name
63
+ orientations.to_file(cls.orientations_path)
68
64
 
69
65
  Density(target, sampling_rate=5).to_file(cls.target_path)
70
66
  Density(template, sampling_rate=5).to_file(cls.template_path)
@@ -76,6 +72,8 @@ class TestMatchTemplate:
76
72
  cls.try_delete(cls.template_path)
77
73
  cls.try_delete(cls.target_mask_path)
78
74
  cls.try_delete(cls.template_mask_path)
75
+ cls.try_delete(cls.orientations_path)
76
+ cls.try_delete(cls.tempdir)
79
77
 
80
78
  @staticmethod
81
79
  def try_delete(file_path: str):
@@ -88,8 +86,8 @@ class TestMatchTemplate:
88
86
  except Exception:
89
87
  pass
90
88
 
91
- @staticmethod
92
89
  def run_matching(
90
+ self,
93
91
  use_template_mask: bool,
94
92
  test_filter: bool,
95
93
  call_peaks: bool,
@@ -99,6 +97,7 @@ class TestMatchTemplate:
99
97
  target_mask_path: str,
100
98
  use_target_mask: bool = False,
101
99
  backend: str = "numpyfftw",
100
+ test_rejection_sampling: bool = False,
102
101
  ):
103
102
  output_path = tempfile.NamedTemporaryFile(delete=False, suffix="pickle").name
104
103
 
@@ -108,25 +107,24 @@ class TestMatchTemplate:
108
107
  "-n": 1,
109
108
  "-a": 60,
110
109
  "-o": output_path,
111
- "--pad_edges": False,
112
- "--pad_fourier": False,
110
+ "--pad-edges": False,
113
111
  "--backend": backend,
114
112
  }
115
113
 
116
114
  if use_template_mask:
117
- argdict["--template_mask"] = template_mask_path
115
+ argdict["--template-mask"] = template_mask_path
118
116
 
119
117
  if use_target_mask:
120
- argdict["--target_mask"] = target_mask_path
118
+ argdict["--target-mask"] = target_mask_path
121
119
 
122
- if backend in ("cupy", "pytorch") and True in test_gpu:
123
- argdict["--use_gpu"] = True
120
+ if test_rejection_sampling:
121
+ argdict["--orientations"] = self.orientations_path
124
122
 
125
123
  if test_filter:
126
124
  argdict["--lowpass"] = 30
127
125
  argdict["--defocus"] = 3000
128
- argdict["--tilt_angles"] = "40,40:10"
129
- argdict["--wedge_axes"] = "0,2"
126
+ argdict["--tilt-angles"] = "40,40"
127
+ argdict["--wedge-axes"] = "2,0"
130
128
  argdict["--whiten"] = True
131
129
 
132
130
  if call_peaks:
@@ -142,13 +140,18 @@ class TestMatchTemplate:
142
140
  @pytest.mark.parametrize("call_peaks", (False, True))
143
141
  @pytest.mark.parametrize("use_template_mask", (False, True))
144
142
  @pytest.mark.parametrize("test_filter", (False, True))
143
+ @pytest.mark.parametrize("test_rejection_sampling", (False, True))
145
144
  def test_match_template(
146
145
  self,
147
146
  backend: bool,
148
147
  call_peaks: bool,
149
148
  use_template_mask: bool,
150
149
  test_filter: bool,
150
+ test_rejection_sampling: bool,
151
151
  ):
152
+ if backend == "jax" and (call_peaks or test_rejection_sampling):
153
+ return None
154
+
152
155
  self.run_matching(
153
156
  use_template_mask=use_template_mask,
154
157
  use_target_mask=True,
@@ -159,6 +162,7 @@ class TestMatchTemplate:
159
162
  target_path=self.target_path,
160
163
  template_mask_path=self.template_mask_path,
161
164
  target_mask_path=self.target_mask_path,
165
+ test_rejection_sampling=test_rejection_sampling,
162
166
  )
163
167
 
164
168
 
@@ -175,20 +179,20 @@ class TestPostprocessing(TestMatchTemplate):
175
179
  "target_path": cls.target_path,
176
180
  "template_mask_path": cls.template_mask_path,
177
181
  "target_mask_path": cls.target_mask_path,
182
+ "test_rejection_sampling": False,
178
183
  }
179
184
 
180
185
  cls.score_pickle = cls.run_matching(
186
+ cls,
181
187
  call_peaks=False,
182
188
  **matching_kwargs,
183
189
  )
184
- cls.peak_pickle = cls.run_matching(call_peaks=True, **matching_kwargs)
185
- cls.tempdir = tempfile.TemporaryDirectory().name
190
+ cls.peak_pickle = cls.run_matching(cls, call_peaks=True, **matching_kwargs)
186
191
 
187
192
  @classmethod
188
193
  def teardown_class(cls):
189
194
  cls.try_delete(cls.score_pickle)
190
195
  cls.try_delete(cls.peak_pickle)
191
- cls.try_delete(cls.tempdir)
192
196
 
193
197
  @pytest.mark.parametrize("distance_cutoff_strategy", (0, 1, 2, 3))
194
198
  @pytest.mark.parametrize("score_cutoff", (None, (1,), (0, 1), (None, 1), (0, None)))
@@ -203,28 +207,28 @@ class TestPostprocessing(TestMatchTemplate):
203
207
  makedirs(self.tempdir, exist_ok=True)
204
208
 
205
209
  argdict = {
206
- "--input_file": self.score_pickle,
207
- "--output_format": "orientations",
208
- "--output_prefix": f"{self.tempdir}/temp",
209
- "--peak_oversampling": peak_oversampling,
210
- "--num_peaks": 3,
210
+ "--input-file": self.score_pickle,
211
+ "--output-format": "orientations",
212
+ "--output-prefix": f"{self.tempdir}/temp",
213
+ "--peak-oversampling": peak_oversampling,
214
+ "--num-peaks": 3,
211
215
  }
212
216
 
213
217
  if score_cutoff is not None:
214
218
  if len(score_cutoff) == 1:
215
- argdict["--n_false_positives"] = 1
219
+ argdict["--n-false-positives"] = 1
216
220
  else:
217
221
  min_score, max_score = score_cutoff
218
- argdict["--minimum_score"] = min_score
219
- argdict["--maximum_score"] = max_score
222
+ argdict["--min-score"] = min_score
223
+ argdict["--max-score"] = max_score
220
224
 
221
225
  match distance_cutoff_strategy:
222
226
  case 1:
223
- argdict["--mask_edges"] = True
227
+ argdict["--mask-edges"] = True
224
228
  case 2:
225
- argdict["--min_distance"] = 5
229
+ argdict["--min-distance"] = 5
226
230
  case 3:
227
- argdict["--min_boundary_distance"] = 5
231
+ argdict["--min-boundary-distance"] = 5
228
232
 
229
233
  cmd = argdict_to_command(argdict, executable="postprocess.py")
230
234
  ret = subprocess.run(cmd, capture_output=True, shell=True)
@@ -244,14 +248,15 @@ class TestPostprocessing(TestMatchTemplate):
244
248
  input_file = self.peak_pickle
245
249
 
246
250
  argdict = {
247
- "--input_file": input_file,
248
- "--output_format": output_format,
249
- "--output_prefix": f"{self.tempdir}/temp",
250
- "--num_peaks": 3,
251
- "--peak_caller": "PeakCallerMaximumFilter",
251
+ "--input-file": input_file,
252
+ "--output-format": output_format,
253
+ "--output-prefix": f"{self.tempdir}/temp",
254
+ "--num-peaks": 3,
255
+ "--peak-caller": "PeakCallerMaximumFilter",
252
256
  }
253
257
  cmd = argdict_to_command(argdict, executable="postprocess.py")
254
258
  ret = subprocess.run(cmd, capture_output=True, shell=True)
259
+ print(ret)
255
260
 
256
261
  match output_format:
257
262
  case "orientations":
@@ -264,7 +269,8 @@ class TestPostprocessing(TestMatchTemplate):
264
269
  assert exists(f"{self.tempdir}/temp.star")
265
270
  case "relion5":
266
271
  assert exists(f"{self.tempdir}/temp.star")
267
-
272
+ case "pickle":
273
+ assert exists(f"{self.tempdir}/temp.pickle")
268
274
  assert ret.returncode == 0
269
275
 
270
276
  def test_postprocess_score_local_optimization(self):
@@ -272,12 +278,62 @@ class TestPostprocessing(TestMatchTemplate):
272
278
  makedirs(self.tempdir, exist_ok=True)
273
279
 
274
280
  argdict = {
275
- "--input_file": self.score_pickle,
276
- "--output_format": "orientations",
277
- "--output_prefix": f"{self.tempdir}/temp",
278
- "--num_peaks": 1,
279
- "--local_optimization": True,
281
+ "--input-file": self.score_pickle,
282
+ "--output-format": "orientations",
283
+ "--output-prefix": f"{self.tempdir}/temp",
284
+ "--num-peaks": 1,
285
+ "--local-optimization": True,
280
286
  }
281
287
  cmd = argdict_to_command(argdict, executable="postprocess.py")
282
288
  ret = subprocess.run(cmd, capture_output=True, shell=True)
283
289
  assert ret.returncode == 0
290
+
291
+
292
+ class TestEstimateMemoryUsage(TestMatchTemplate):
293
+ @classmethod
294
+ def setup_class(cls):
295
+ super().setup_class()
296
+
297
+ @pytest.mark.parametrize("ncores", (1, 4, 8))
298
+ @pytest.mark.parametrize("pad_edges", (False, True))
299
+ def test_estimation(self, ncores, pad_edges):
300
+
301
+ argdict = {
302
+ "-m": self.target_path,
303
+ "-i": self.template_path,
304
+ "--ncores": ncores,
305
+ "--pad-edges": pad_edges,
306
+ "--score": "FLCSphericalMask",
307
+ }
308
+
309
+ cmd = argdict_to_command(argdict, executable="estimate_memory_usage.py")
310
+ ret = subprocess.run(cmd, capture_output=True, shell=True)
311
+ assert ret.returncode == 0
312
+
313
+
314
+ class TestPreprocess(TestMatchTemplate):
315
+ @classmethod
316
+ def setup_class(cls):
317
+ super().setup_class()
318
+
319
+ @pytest.mark.parametrize("backend", available_backends)
320
+ @pytest.mark.parametrize("align_axis", (False, True))
321
+ @pytest.mark.parametrize("invert_contrast", (False, True))
322
+ def test_estimation(self, backend, align_axis, invert_contrast):
323
+
324
+ argdict = {
325
+ "-m": self.target_path,
326
+ "--backend": backend,
327
+ "--lowpass": 40,
328
+ "--sampling-rate": 5,
329
+ "-o": f"{self.tempdir}/out.mrc",
330
+ }
331
+ if align_axis:
332
+ argdict["--align-axis"] = 2
333
+
334
+ if invert_contrast:
335
+ argdict["--invert-contrast"] = True
336
+
337
+ cmd = argdict_to_command(argdict, executable="preprocess.py")
338
+ ret = subprocess.run(cmd, capture_output=True, shell=True)
339
+ assert ret.returncode == 0
@@ -7,7 +7,7 @@ from tme.backends import backend as be
7
7
  from tme.matching_data import MatchingData
8
8
 
9
9
 
10
- class TestDensity:
10
+ class TestMatchingData:
11
11
  def setup_method(self):
12
12
  target = np.zeros((50, 50, 50))
13
13
  target[20:30, 30:40, 12:17] = 1
@@ -87,9 +87,9 @@ class TestDensity:
87
87
  matching_data.target_mask = self.target
88
88
  matching_data.template_mask = self.template
89
89
 
90
- ret = matching_data.subset_by_slice()
90
+ ret, offset = matching_data.subset_by_slice()
91
91
 
92
- assert type(ret) == type(matching_data)
92
+ assert isinstance(ret, type(matching_data))
93
93
  assert np.allclose(ret.target, matching_data.target)
94
94
  assert np.allclose(ret.template, matching_data.template)
95
95
  assert np.allclose(ret.target_mask, matching_data.target_mask)
@@ -107,10 +107,10 @@ class TestDensity:
107
107
  template_slice = MatchingData._shape_to_slice(
108
108
  shape=np.divide(self.template.shape, 2).astype(int)
109
109
  )
110
- ret = matching_data.subset_by_slice(
110
+ ret, offset = matching_data.subset_by_slice(
111
111
  target_slice=target_slice, template_slice=template_slice
112
112
  )
113
- assert type(ret) == type(matching_data)
113
+ assert isinstance(ret, type(matching_data))
114
114
 
115
115
  assert np.allclose(
116
116
  ret.target.shape, np.divide(self.target.shape, 2).astype(int)
@@ -7,7 +7,6 @@ from tme.matching_data import MatchingData
7
7
  from tme.memory import MATCHING_MEMORY_REGISTRY
8
8
  from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
9
9
  from tme.matching_exhaustive import (
10
- scan,
11
10
  scan_subsets,
12
11
  MATCHING_EXHAUSTIVE_REGISTER,
13
12
  register_matching_exhaustive,
@@ -36,7 +35,7 @@ class TestMatchExhaustive:
36
35
  self.coordinates_weights = None
37
36
  self.rotations = None
38
37
 
39
- @pytest.mark.parametrize("evaluate_peak", (False, True))
38
+ @pytest.mark.parametrize("evaluate_peak", (True,))
40
39
  @pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
41
40
  @pytest.mark.parametrize("job_schedule", ((2, 1),))
42
41
  @pytest.mark.parametrize("pad_edge", (False, True))
@@ -24,6 +24,7 @@ coordinate_to_coordinate = [
24
24
  for k, v in MATCHING_OPTIMIZATION_REGISTER.items()
25
25
  if issubclass(v, _MatchCoordinatesToCoordinates)
26
26
  ]
27
+ np.random.seed(42)
27
28
 
28
29
 
29
30
  class TestMatchDensityToDensity:
@@ -52,9 +53,7 @@ class TestMatchDensityToDensity:
52
53
 
53
54
  @pytest.mark.parametrize("method", density_to_density)
54
55
  def test_call(self, method):
55
- instance = self.test_initialization(method=method, notest=True)
56
- score = instance()
57
- assert isinstance(score, float)
56
+ self.test_initialization(method=method, notest=True)()
58
57
 
59
58
 
60
59
  class TestMatchDensityToCoordinates:
@@ -97,9 +96,7 @@ class TestMatchDensityToCoordinates:
97
96
 
98
97
  @pytest.mark.parametrize("method", coordinate_to_density)
99
98
  def test_call(self, method):
100
- instance = self.test_initialization(method=method, notest=True)
101
- score = instance()
102
- assert isinstance(score, float)
99
+ self.test_initialization(method=method, notest=True)()
103
100
 
104
101
 
105
102
  class TestMatchCoordinateToCoordinates:
@@ -135,9 +132,7 @@ class TestMatchCoordinateToCoordinates:
135
132
 
136
133
  @pytest.mark.parametrize("method", coordinate_to_coordinate)
137
134
  def test_call(self, method):
138
- instance = self.test_initialization(method=method, notest=True)
139
- score = instance()
140
- assert isinstance(score, float)
135
+ self.test_initialization(method=method, notest=True)()
141
136
 
142
137
 
143
138
  class TestOptimizeMatch: