pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/match_template.py +163 -201
  3. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +48 -39
  4. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +3 -4
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +14 -14
  8. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/RECORD +54 -50
  9. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/match_template.py +163 -201
  14. scripts/match_template_filters.py +1200 -0
  15. scripts/postprocess.py +48 -39
  16. scripts/preprocess.py +10 -23
  17. scripts/preprocessor_gui.py +3 -4
  18. scripts/pytme_runner.py +769 -0
  19. scripts/refine_matches.py +0 -1
  20. tests/preprocessing/test_frequency_filters.py +19 -10
  21. tests/test_analyzer.py +122 -122
  22. tests/test_backends.py +1 -0
  23. tests/test_matching_cli.py +30 -30
  24. tests/test_matching_data.py +5 -5
  25. tests/test_matching_utils.py +1 -1
  26. tme/__version__.py +1 -1
  27. tme/analyzer/__init__.py +1 -1
  28. tme/analyzer/_utils.py +1 -4
  29. tme/analyzer/aggregation.py +15 -6
  30. tme/analyzer/base.py +25 -36
  31. tme/analyzer/peaks.py +39 -113
  32. tme/analyzer/proxy.py +1 -0
  33. tme/backends/_jax_utils.py +16 -15
  34. tme/backends/cupy_backend.py +9 -13
  35. tme/backends/jax_backend.py +19 -16
  36. tme/backends/npfftw_backend.py +27 -25
  37. tme/backends/pytorch_backend.py +4 -0
  38. tme/density.py +5 -4
  39. tme/filters/__init__.py +2 -2
  40. tme/filters/_utils.py +32 -7
  41. tme/filters/bandpass.py +225 -186
  42. tme/filters/ctf.py +117 -67
  43. tme/filters/reconstruction.py +38 -9
  44. tme/filters/wedge.py +88 -105
  45. tme/filters/whitening.py +1 -6
  46. tme/matching_data.py +24 -36
  47. tme/matching_exhaustive.py +14 -11
  48. tme/matching_scores.py +21 -12
  49. tme/matching_utils.py +13 -6
  50. tme/orientations.py +13 -3
  51. tme/parser.py +109 -29
  52. tme/preprocessor.py +2 -2
  53. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  54. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  55. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
scripts/refine_matches.py CHANGED
@@ -494,7 +494,6 @@ class DeepMatcher:
 
 
 def main():
-    print("Entered")
     args = parse_args()
 
     if args.input_file is not None:
tests/preprocessing/test_frequency_filters.py CHANGED
@@ -4,13 +4,15 @@ from typing import Tuple
 
 from tme.backends import backend as be
 from tme.filters._utils import compute_fourier_shape
-from tme.filters import BandPassFilter, LinearWhiteningFilter
+from tme.filters import BandPassReconstructed, LinearWhiteningFilter
+from tme.filters.bandpass import gaussian_bandpass, discrete_bandpass
+from tme.filters._utils import fftfreqn
 
 
 class TestBandPassFilter:
     @pytest.fixture
     def band_pass_filter(self):
-        return BandPassFilter()
+        return BandPassReconstructed()
 
     @pytest.mark.parametrize(
         "shape, lowpass, highpass, sampling_rate",
@@ -24,9 +26,13 @@ class TestBandPassFilter:
     def test_discrete_bandpass(
         self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
     ):
-        result = BandPassFilter.discrete_bandpass(
-            shape, lowpass, highpass, sampling_rate
+        grid = fftfreqn(
+            shape=shape,
+            sampling_rate=0.5,
+            shape_is_real_fourier=False,
+            compute_euclidean_norm=True,
         )
+        result = discrete_bandpass(grid, lowpass, highpass, sampling_rate)
         assert isinstance(result, type(be.ones((1,))))
         assert result.shape == shape
         assert np.all((result >= 0) & (result <= 1))
@@ -43,9 +49,13 @@ class TestBandPassFilter:
     def test_gaussian_bandpass(
         self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
     ):
-        result = BandPassFilter.gaussian_bandpass(
-            shape, lowpass, highpass, sampling_rate
+        grid = fftfreqn(
+            shape=shape,
+            sampling_rate=0.5,
+            shape_is_real_fourier=False,
+            compute_euclidean_norm=True,
         )
+        result = gaussian_bandpass(grid, lowpass, highpass, sampling_rate)
         assert isinstance(result, type(be.ones((1,))))
         assert result.shape == shape
         assert np.all((result >= 0) & (result <= 1))
@@ -55,7 +65,7 @@ class TestBandPassFilter:
     @pytest.mark.parametrize("shape_is_real_fourier", [True, False])
     def test_call_method(
         self,
-        band_pass_filter: BandPassFilter,
+        band_pass_filter: BandPassReconstructed,
         use_gaussian: bool,
         return_real_fourier: bool,
         shape_is_real_fourier: bool,
@@ -72,17 +82,16 @@ class TestBandPassFilter:
         assert isinstance(result["data"], type(be.ones((1,))))
         assert result["is_multiplicative_filter"] is True
 
-    def test_default_values(self, band_pass_filter: BandPassFilter):
+    def test_default_values(self, band_pass_filter: BandPassReconstructed):
         assert band_pass_filter.lowpass is None
         assert band_pass_filter.highpass is None
         assert band_pass_filter.sampling_rate == 1
         assert band_pass_filter.use_gaussian is True
         assert band_pass_filter.return_real_fourier is False
-        assert band_pass_filter.shape_is_real_fourier is False
 
     @pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
     def test_return_real_fourier(self, shape: Tuple[int]):
-        bpf = BandPassFilter(return_real_fourier=True)
+        bpf = BandPassReconstructed(return_real_fourier=True)
        result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
        expected_shape = tuple(compute_fourier_shape(shape, False))
        assert result["data"].shape == expected_shape
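
The updated tests above reflect a move away from the BandPassFilter static methods toward a functional API: a frequency grid is built once with fftfreqn and then handed to discrete_bandpass or gaussian_bandpass. A minimal sketch of that call pattern, assuming pytme 0.3b0.post1 is installed and using only the signatures exercised by the tests; the shape and cutoff values are illustrative:

from tme.filters._utils import fftfreqn
from tme.filters.bandpass import discrete_bandpass, gaussian_bandpass

shape = (32, 32, 32)

# Euclidean frequency grid over the full (non-rfft) Fourier shape,
# mirroring the grid construction in the updated tests.
grid = fftfreqn(
    shape=shape,
    sampling_rate=0.5,
    shape_is_real_fourier=False,
    compute_euclidean_norm=True,
)

# Band-pass masks on that grid; arguments follow the positional order used in
# the tests (lowpass, highpass, sampling_rate), and values lie in [0, 1].
hard_mask = discrete_bandpass(grid, 20, 60, 2)
smooth_mask = gaussian_bandpass(grid, 20, 60, 2)

assert hard_mask.shape == shape and smooth_mask.shape == shape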
tests/test_analyzer.py CHANGED
@@ -97,125 +97,125 @@ class TestPeakCallers:
         assert [len(res) == 2 for res in result]
 
 
-# class TestRecursiveMasking:
-#     def setup_method(self):
-#         self.num_peaks = 100
-#         self.min_distance = 5
-#         self.data = np.random.rand(100, 100, 100)
-#         self.rotation_matrix = np.eye(3)
-#         self.mask = np.random.rand(20, 20, 20)
-#         self.rotation_space = np.zeros_like(self.data)
-#         self.rotation_mapping = {0: (0, 0, 0)}
-
-#     @pytest.mark.parametrize("num_peaks", (1, 100))
-#     @pytest.mark.parametrize("compute_rotation", (True, False))
-#     @pytest.mark.parametrize("minimum_score", (None, 0.5))
-#     def test__call__(self, num_peaks, compute_rotation, minimum_score):
-#         peak_caller = PeakCallerRecursiveMasking(
-#             shape=self.data.shape,
-#             num_peaks=num_peaks,
-#             min_distance=self.min_distance,
-#             min_score=minimum_score,
-#         )
-#         rotation_space, rotation_mapping = None, None
-#         if compute_rotation:
-#             rotation_space = self.rotation_space
-#             rotation_mapping = self.rotation_mapping
-
-#         state = peak_caller.init_state()
-#         state = peak_caller(
-#             state,
-#             self.data.copy(),
-#             rotation_matrix=self.rotation_matrix,
-#             mask=self.mask,
-#             rotation_space=rotation_space,
-#             rotation_mapping=rotation_mapping,
-#         )
-
-#         if minimum_score is None:
-#             assert len(state[0] <= num_peaks)
-#         else:
-#             peaks = state[0].astype(int)
-#             assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
-
-
-# class TestMaxScoreOverRotations:
-#     def setup_method(self):
-#         self.num_peaks = 100
-#         self.min_distance = 5
-#         self.data = np.random.rand(100, 100, 100)
-#         self.rotation_matrix = np.eye(3)
-
-#     def test_initialization(self):
-#         _ = MaxScoreOverRotations(
-#             shape=self.data.shape,
-#             translation_offset=np.zeros(self.data.ndim, dtype=int),
-#         )
-
-#     @pytest.mark.parametrize("use_memmap", [False, True])
-#     def test__iter__(self, use_memmap: bool):
-#         score_analyzer = MaxScoreOverRotations(
-#             shape=self.data.shape,
-#             use_memmap=use_memmap,
-#         )
-#         state = score_analyzer.init_state()
-#         state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
-#         res = score_analyzer.result(state)
-#         assert np.allclose(res[0].shape, self.data.shape)
-#         assert res[0].dtype == be._float_dtype
-#         assert res[1].size == self.data.ndim
-#         assert np.allclose(res[2].shape, self.data.shape)
-#         assert len(res) == 4
-
-#     @pytest.mark.parametrize("use_memmap", [False, True])
-#     @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
-#     def test__call__(self, use_memmap: bool, score_threshold: float):
-#         score_analyzer = MaxScoreOverRotations(
-#             shape=self.data.shape,
-#             score_threshold=score_threshold,
-#             translation_offset=np.zeros(self.data.ndim, dtype=int),
-#             use_memmap=use_memmap,
-#         )
-#         state = score_analyzer.init_state()
-#         state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
-
-#         data2 = self.data * 2
-#         score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
-#         scores, translation_offset, rotations, mapping = score_analyzer.result(state)
-
-#         assert np.all(scores >= score_threshold)
-#         max_scores = np.maximum(self.data, data2)
-#         max_scores = np.maximum(max_scores, score_threshold)
-#         assert np.allclose(scores, max_scores)
-
-#     @pytest.mark.parametrize("use_memmap", [False, True])
-#     @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
-#     def test_merge(self, use_memmap: bool, score_threshold: float):
-#         score_analyzer = MaxScoreOverRotations(
-#             shape=self.data.shape,
-#             score_threshold=score_threshold,
-#             translation_offset=np.zeros(self.data.ndim, dtype=int),
-#             use_memmap=use_memmap,
-#         )
-#         state1 = score_analyzer.init_state()
-#         state1, score_analyzer(state1, self.data, rotation_matrix=self.rotation_matrix)
-
-#         data2 = self.data * 2
-#         score_analyzer2 = MaxScoreOverRotations(
-#             shape=self.data.shape,
-#             score_threshold=score_threshold,
-#             translation_offset=np.zeros(self.data.ndim, dtype=int),
-#             use_memmap=use_memmap,
-#         )
-#         state2 = score_analyzer2.init_state()
-#         state2 = score_analyzer2(state2, data2, rotation_matrix=self.rotation_matrix)
-#         states = [score_analyzer.result(state1), score_analyzer2.result(state2)]
-
-#         ret = MaxScoreOverRotations.merge(
-#             results=states, use_memmap=use_memmap, score_threshold=score_threshold
-#         )
-#         scores, translation, rotations, mapping = ret
-#         assert np.all(scores >= score_threshold)
-#         max_scores = np.maximum(self.data, data2)
-#         max_scores = np.maximum(max_scores, score_threshold)
-#         assert np.allclose(scores, max_scores)
+class TestRecursiveMasking:
+    def setup_method(self):
+        self.num_peaks = 100
+        self.min_distance = 5
+        self.data = np.random.rand(100, 100, 100)
+        self.rotation_matrix = np.eye(3)
+        self.mask = np.random.rand(20, 20, 20)
+        self.rotation_space = np.zeros_like(self.data)
+        self.rotation_mapping = {0: (0, 0, 0)}
+
+    @pytest.mark.parametrize("num_peaks", (1, 100))
+    @pytest.mark.parametrize("compute_rotation", (True, False))
+    @pytest.mark.parametrize("minimum_score", (None, 0.5))
+    def test__call__(self, num_peaks, compute_rotation, minimum_score):
+        peak_caller = PeakCallerRecursiveMasking(
+            shape=self.data.shape,
+            num_peaks=num_peaks,
+            min_distance=self.min_distance,
+            min_score=minimum_score,
+        )
+        rotation_space, rotation_mapping = None, None
+        if compute_rotation:
+            rotation_space = self.rotation_space
+            rotation_mapping = self.rotation_mapping
+
+        state = peak_caller.init_state()
+        state = peak_caller(
+            state,
+            self.data.copy(),
+            rotation_matrix=self.rotation_matrix,
+            mask=self.mask,
+            rotation_space=rotation_space,
+            rotation_mapping=rotation_mapping,
+        )
+
+        if minimum_score is None:
+            assert len(state[0] <= num_peaks)
+        else:
+            peaks = state[0].astype(int)
+            assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
+
+
+class TestMaxScoreOverRotations:
+    def setup_method(self):
+        self.num_peaks = 100
+        self.min_distance = 5
+        self.data = np.random.rand(100, 100, 100)
+        self.rotation_matrix = np.eye(3)
+
+    def test_initialization(self):
+        _ = MaxScoreOverRotations(
+            shape=self.data.shape,
+            translation_offset=np.zeros(self.data.ndim, dtype=int),
+        )
+
+    @pytest.mark.parametrize("use_memmap", [False, True])
+    def test__iter__(self, use_memmap: bool):
+        score_analyzer = MaxScoreOverRotations(
+            shape=self.data.shape,
+            use_memmap=use_memmap,
+        )
+        state = score_analyzer.init_state()
+        state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
+        res = score_analyzer.result(state)
+        assert np.allclose(res[0].shape, self.data.shape)
+        assert res[0].dtype == be._float_dtype
+        assert res[1].size == self.data.ndim
+        assert np.allclose(res[2].shape, self.data.shape)
+        assert len(res) == 4
+
+    @pytest.mark.parametrize("use_memmap", [False, True])
+    @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
+    def test__call__(self, use_memmap: bool, score_threshold: float):
+        score_analyzer = MaxScoreOverRotations(
+            shape=self.data.shape,
+            score_threshold=score_threshold,
+            translation_offset=np.zeros(self.data.ndim, dtype=int),
+            use_memmap=use_memmap,
+        )
+        state = score_analyzer.init_state()
+        state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
+
+        data2 = self.data * 2
+        score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
+        scores, translation_offset, rotations, mapping = score_analyzer.result(state)
+
+        assert np.all(scores >= score_threshold)
+        max_scores = np.maximum(self.data, data2)
+        max_scores = np.maximum(max_scores, score_threshold)
+        assert np.allclose(scores, max_scores)
+
+    @pytest.mark.parametrize("use_memmap", [False, True])
+    @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
+    def test_merge(self, use_memmap: bool, score_threshold: float):
+        score_analyzer = MaxScoreOverRotations(
+            shape=self.data.shape,
+            score_threshold=score_threshold,
+            translation_offset=np.zeros(self.data.ndim, dtype=int),
+            use_memmap=use_memmap,
+        )
+        state1 = score_analyzer.init_state()
+        state1, score_analyzer(state1, self.data, rotation_matrix=self.rotation_matrix)
+
+        data2 = self.data * 2
+        score_analyzer2 = MaxScoreOverRotations(
+            shape=self.data.shape,
+            score_threshold=score_threshold,
+            translation_offset=np.zeros(self.data.ndim, dtype=int),
+            use_memmap=use_memmap,
+        )
+        state2 = score_analyzer2.init_state()
+        state2 = score_analyzer2(state2, data2, rotation_matrix=self.rotation_matrix)
+        states = [score_analyzer.result(state1), score_analyzer2.result(state2)]
+
+        ret = MaxScoreOverRotations.merge(
+            results=states, use_memmap=use_memmap, score_threshold=score_threshold
+        )
+        scores, translation, rotations, mapping = ret
+        assert np.all(scores >= score_threshold)
+        max_scores = np.maximum(self.data, data2)
+        max_scores = np.maximum(max_scores, score_threshold)
+        assert np.allclose(scores, max_scores)
tests/test_backends.py CHANGED
@@ -423,6 +423,7 @@ class TestBackends:
             rotation_matrix=rotation_matrix,
             out=out,
             out_mask=out_mask,
+            batched=True,
         )
 
         arr_b = backend.to_backend_array(arr_b)
tests/test_matching_cli.py CHANGED
@@ -107,15 +107,15 @@ class TestMatchTemplate:
             "-n": 1,
             "-a": 60,
             "-o": output_path,
-            "--pad_edges": False,
+            "--pad-edges": False,
             "--backend": backend,
         }
 
         if use_template_mask:
-            argdict["--template_mask"] = template_mask_path
+            argdict["--template-mask"] = template_mask_path
 
         if use_target_mask:
-            argdict["--target_mask"] = target_mask_path
+            argdict["--target-mask"] = target_mask_path
 
         if test_rejection_sampling:
             argdict["--orientations"] = self.orientations_path
@@ -123,8 +123,8 @@ class TestMatchTemplate:
         if test_filter:
             argdict["--lowpass"] = 30
             argdict["--defocus"] = 3000
-            argdict["--tilt_angles"] = "40,40"
-            argdict["--wedge_axes"] = "2,0"
+            argdict["--tilt-angles"] = "40,40"
+            argdict["--wedge-axes"] = "2,0"
             argdict["--whiten"] = True
 
         if call_peaks:
@@ -207,28 +207,28 @@ class TestPostprocessing(TestMatchTemplate):
         makedirs(self.tempdir, exist_ok=True)
 
         argdict = {
-            "--input_file": self.score_pickle,
-            "--output_format": "orientations",
-            "--output_prefix": f"{self.tempdir}/temp",
-            "--peak_oversampling": peak_oversampling,
-            "--num_peaks": 3,
+            "--input-file": self.score_pickle,
+            "--output-format": "orientations",
+            "--output-prefix": f"{self.tempdir}/temp",
+            "--peak-oversampling": peak_oversampling,
+            "--num-peaks": 3,
         }
 
         if score_cutoff is not None:
             if len(score_cutoff) == 1:
-                argdict["--n_false_positives"] = 1
+                argdict["--n-false-positives"] = 1
             else:
                 min_score, max_score = score_cutoff
-                argdict["--minimum_score"] = min_score
-                argdict["--maximum_score"] = max_score
+                argdict["--min-score"] = min_score
+                argdict["--max-score"] = max_score
 
         match distance_cutoff_strategy:
             case 1:
-                argdict["--mask_edges"] = True
+                argdict["--mask-edges"] = True
             case 2:
-                argdict["--min_distance"] = 5
+                argdict["--min-distance"] = 5
             case 3:
-                argdict["--min_boundary_distance"] = 5
+                argdict["--min-boundary-distance"] = 5
 
         cmd = argdict_to_command(argdict, executable="postprocess.py")
         ret = subprocess.run(cmd, capture_output=True, shell=True)
@@ -248,11 +248,11 @@ class TestPostprocessing(TestMatchTemplate):
             input_file = self.peak_pickle
 
         argdict = {
-            "--input_file": input_file,
-            "--output_format": output_format,
-            "--output_prefix": f"{self.tempdir}/temp",
-            "--num_peaks": 3,
-            "--peak_caller": "PeakCallerMaximumFilter",
+            "--input-file": input_file,
+            "--output-format": output_format,
+            "--output-prefix": f"{self.tempdir}/temp",
+            "--num-peaks": 3,
+            "--peak-caller": "PeakCallerMaximumFilter",
         }
         cmd = argdict_to_command(argdict, executable="postprocess.py")
         ret = subprocess.run(cmd, capture_output=True, shell=True)
@@ -278,11 +278,11 @@ class TestPostprocessing(TestMatchTemplate):
         makedirs(self.tempdir, exist_ok=True)
 
         argdict = {
-            "--input_file": self.score_pickle,
-            "--output_format": "orientations",
-            "--output_prefix": f"{self.tempdir}/temp",
-            "--num_peaks": 1,
-            "--local_optimization": True,
+            "--input-file": self.score_pickle,
+            "--output-format": "orientations",
+            "--output-prefix": f"{self.tempdir}/temp",
+            "--num-peaks": 1,
+            "--local-optimization": True,
         }
         cmd = argdict_to_command(argdict, executable="postprocess.py")
         ret = subprocess.run(cmd, capture_output=True, shell=True)
@@ -302,7 +302,7 @@ class TestEstimateMemoryUsage(TestMatchTemplate):
             "-m": self.target_path,
             "-i": self.template_path,
             "--ncores": ncores,
-            "--pad_edges": pad_edges,
+            "--pad-edges": pad_edges,
             "--score": "FLCSphericalMask",
         }
 
@@ -325,14 +325,14 @@ class TestPreprocess(TestMatchTemplate):
             "-m": self.target_path,
             "--backend": backend,
             "--lowpass": 40,
-            "--sampling_rate": 5,
+            "--sampling-rate": 5,
             "-o": f"{self.tempdir}/out.mrc",
         }
         if align_axis:
-            argdict["--align_axis"] = 2
+            argdict["--align-axis"] = 2
 
         if invert_contrast:
-            argdict["--invert_contrast"] = True
+            argdict["--invert-contrast"] = True
 
         cmd = argdict_to_command(argdict, executable="preprocess.py")
         ret = subprocess.run(cmd, capture_output=True, shell=True)
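
All of the user-facing options exercised by these tests move from snake_case to kebab-case (--input_file becomes --input-file, --pad_edges becomes --pad-edges, and so on). A minimal sketch of how such an argument dictionary turns into a postprocess.py invocation; the argdict_to_command below is an illustrative stand-in for the test helper of the same name, not the packaged implementation, and the file paths are made up:

from typing import Any, Dict


def argdict_to_command(argdict: Dict[str, Any], executable: str) -> str:
    """Illustrative stand-in: render {flag: value} pairs as a shell command string."""
    parts = [executable]
    for flag, value in argdict.items():
        if value is True:
            parts.append(flag)  # boolean switch, no value
        elif value is False or value is None:
            continue  # omitted entirely
        else:
            parts.extend([flag, str(value)])
    return " ".join(parts)


argdict = {
    "--input-file": "scores.pickle",
    "--output-format": "orientations",
    "--output-prefix": "out/temp",
    "--num-peaks": 3,
    "--peak-caller": "PeakCallerMaximumFilter",
}
print(argdict_to_command(argdict, executable="postprocess.py"))
# postprocess.py --input-file scores.pickle --output-format orientations ...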
tests/test_matching_data.py CHANGED
@@ -7,7 +7,7 @@ from tme.backends import backend as be
 from tme.matching_data import MatchingData
 
 
-class TestDensity:
+class TestMatchingData:
     def setup_method(self):
         target = np.zeros((50, 50, 50))
         target[20:30, 30:40, 12:17] = 1
@@ -87,9 +87,9 @@ class TestDensity:
         matching_data.target_mask = self.target
         matching_data.template_mask = self.template
 
-        ret = matching_data.subset_by_slice()
+        ret, offset = matching_data.subset_by_slice()
 
-        assert type(ret) == type(matching_data)
+        assert isinstance(ret, type(matching_data))
         assert np.allclose(ret.target, matching_data.target)
         assert np.allclose(ret.template, matching_data.template)
         assert np.allclose(ret.target_mask, matching_data.target_mask)
@@ -107,10 +107,10 @@ class TestDensity:
         template_slice = MatchingData._shape_to_slice(
             shape=np.divide(self.template.shape, 2).astype(int)
         )
-        ret = matching_data.subset_by_slice(
+        ret, offset = matching_data.subset_by_slice(
             target_slice=target_slice, template_slice=template_slice
         )
-        assert type(ret) == type(matching_data)
+        assert isinstance(ret, type(matching_data))
 
         assert np.allclose(
             ret.target.shape, np.divide(self.target.shape, 2).astype(int)
tests/test_matching_utils.py CHANGED
@@ -139,7 +139,7 @@ class TestMatchingUtils:
         expected_size = np.subtract(
             self.density.shape, self.structure_density.shape
         )
-        expected_size += np.mod(self.structure_density.shape, 2)
+        expected_size += 1
         assert np.allclose(ret.shape, expected_size)
 
     def test_apply_convolution_mode_error(self):
tme/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.3.b0"
+__version__ = "0.3.b0.post1"
tme/analyzer/__init__.py CHANGED
@@ -1,4 +1,4 @@
 from .peaks import *
 from .aggregation import *
 from .proxy import *
-from .base import *
+from .base import *
tme/analyzer/_utils.py CHANGED
@@ -47,10 +47,7 @@ def _convmode_to_shape(
     if convolution_mode == "same":
         output_shape = targetshape
     elif convolution_mode == "valid":
-        output_shape = be.add(
-            be.subtract(targetshape, templateshape),
-            be.mod(templateshape, 2),
-        )
+        output_shape = be.subtract(targetshape, templateshape) + 1
     return be.to_backend_array(output_shape)
 
 
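
The rewritten "valid" branch drops the be.mod(templateshape, 2) correction and adopts the standard correlation convention of one output element per fully-overlapping placement, i.e. target - template + 1 per axis; the expected_size change in tests/test_matching_utils.py above mirrors this. A small NumPy check of that shape rule, with illustrative shapes:

import numpy as np

target_shape = np.array((50, 50, 50))
template_shape = np.array((20, 21, 17))

# One output position per placement where the template fits entirely inside the target.
valid_shape = target_shape - template_shape + 1
print(valid_shape)  # [31 30 34]

# The same convention NumPy uses for its 1D "valid" mode.
assert np.convolve(np.ones(50), np.ones(20), mode="valid").size == 50 - 20 + 1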
tme/analyzer/aggregation.py CHANGED
@@ -57,22 +57,23 @@ class MaxScoreOverRotations(AbstractAnalyzer):
     The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
     instance
 
+    >>> import numpy as np
     >>> from tme.analyzer import MaxScoreOverRotations
-    >>> analyzer = MaxScoreOverRotations(shape = (50, 50))
+    >>> analyzer = MaxScoreOverRotations(shape=(50, 50))
 
     The following simulates a template matching run by creating random data for a range
     of rotations and sending it to ``analyzer`` via its __call__ method
 
-    >> state = analyzer.init_state()
+    >>> state = analyzer.init_state()
     >>> for rotation_number in range(10):
     >>>     scores = np.random.rand(50,50)
     >>>     rotation = np.random.rand(scores.ndim, scores.ndim)
-    >>>     state, analyzer(state, scores = scores, rotation_matrix = rotation)
+    >>>     state = analyzer(state, scores=scores, rotation_matrix=rotation)
 
     The aggregated scores can be extracted by invoking the result method of
     ``analyzer``
 
-    >>> results = analyzer.result()
+    >>> results = analyzer.result(state)
 
     The ``results`` tuple contains (1) the maximum scores for each translation,
     (2) an offset which is relevant when merging results from split template matching
@@ -100,6 +101,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         shm_handler: object = None,
         use_memmap: bool = False,
         inversion_mapping: bool = False,
+        jax_mode: bool = False,
         **kwargs,
     ):
         self._use_memmap = use_memmap
@@ -107,6 +109,10 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         self._shape = tuple(int(x) for x in shape)
         self._inversion_mapping = inversion_mapping
 
+        self._jax_mode = jax_mode
+        if self._jax_mode:
+            self._inversion_mapping = False
+
         if offset is None:
             offset = be.zeros(len(self._shape), be._int_dtype)
         self._offset = be.astype(be.to_backend_array(offset), int)
@@ -138,6 +144,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         state: Tuple,
         scores: BackendArray,
         rotation_matrix: BackendArray,
+        **kwargs,
     ) -> Tuple:
         """
         Update the parameter store.
@@ -167,6 +174,8 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
         if self._inversion_mapping:
             rotation_mapping[rotation_index] = rotation_matrix
+        elif self._jax_mode:
+            rotation_index = kwargs.get("rotation_index", 0)
         else:
             rotation = be.tobytes(rotation_matrix)
             rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
@@ -738,5 +747,5 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
             cls._invert_rmap(master_rotation_mapping),
         )
 
-    def _postprocess(self, **kwargs):
-        return self
+    def result(self, state: Tuple, **kwargs) -> Tuple:
+        return state
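
Together with the re-enabled analyzer tests, these hunks keep MaxScoreOverRotations on an explicit state-passing protocol: init_state() creates the accumulator, __call__ folds in one scored rotation and returns the updated state, and result(state) yields a four-element tuple of maximum scores, translation offset, best rotation index per translation, and the index-to-rotation mapping; when the new jax_mode flag is set, the rotation index comes from the rotation_index keyword rather than a bytes-keyed mapping. A minimal sketch of the default (non-jax) protocol, assuming pytme 0.3b0.post1 is installed and using random scores as stand-ins:

import numpy as np

from tme.analyzer import MaxScoreOverRotations

analyzer = MaxScoreOverRotations(shape=(50, 50))

state = analyzer.init_state()
for _ in range(10):
    scores = np.random.rand(50, 50)                      # one rotation's score map
    rotation = np.random.rand(scores.ndim, scores.ndim)  # placeholder rotation matrix
    state = analyzer(state, scores=scores, rotation_matrix=rotation)

# Maximum score per translation, translation offset, best rotation index per
# translation, and the index -> rotation matrix mapping.
max_scores, offset, rotation_indices, rotation_mapping = analyzer.result(state)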