pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/postprocess.py +35 -21
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.post1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/METADATA +5 -7
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/RECORD +55 -48
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +35 -21
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_analyzer.py +2 -3
  20. tests/test_backends.py +3 -9
  21. tests/test_density.py +0 -1
  22. tests/test_extensions.py +0 -1
  23. tests/test_matching_utils.py +10 -60
  24. tests/test_rotations.py +1 -1
  25. tme/__version__.py +1 -1
  26. tme/analyzer/_utils.py +4 -4
  27. tme/analyzer/aggregation.py +35 -15
  28. tme/analyzer/peaks.py +11 -10
  29. tme/backends/_jax_utils.py +26 -13
  30. tme/backends/_numpyfftw_utils.py +270 -0
  31. tme/backends/cupy_backend.py +16 -55
  32. tme/backends/jax_backend.py +76 -37
  33. tme/backends/matching_backend.py +17 -51
  34. tme/backends/mlx_backend.py +1 -27
  35. tme/backends/npfftw_backend.py +71 -65
  36. tme/backends/pytorch_backend.py +1 -26
  37. tme/density.py +2 -6
  38. tme/extensions.cpython-311-darwin.so +0 -0
  39. tme/filters/ctf.py +22 -21
  40. tme/filters/wedge.py +10 -7
  41. tme/mask.py +341 -0
  42. tme/matching_data.py +31 -19
  43. tme/matching_exhaustive.py +37 -47
  44. tme/matching_optimization.py +2 -1
  45. tme/matching_scores.py +229 -411
  46. tme/matching_utils.py +73 -422
  47. tme/memory.py +1 -1
  48. tme/orientations.py +13 -8
  49. tme/rotations.py +1 -1
  50. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  51. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/estimate_memory_usage.py +0 -0
  52. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocess.py +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/WHEEL +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/entry_points.txt +0 -0
  55. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/licenses/LICENSE +0 -0
  56. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/top_level.txt +0 -0
tests/data/.DS_Store ADDED
Binary file
Binary file
Binary file
Binary file
Binary file
tests/preprocessing/test_utils.py CHANGED
@@ -51,6 +51,24 @@ class TestPreprocessUtils:
51
51
  assert fgrid.shape == tuple(tilt_shape)
52
52
  assert fgrid.max() <= np.sqrt(1 / sampling_rate * len(shape))
53
53
 
54
+ @pytest.mark.parametrize("shape", ((15, 15, 15), (31, 31, 31), (64, 64, 64)))
55
+ @pytest.mark.parametrize("sampling_rate", (0.5, 1, 2))
56
+ @pytest.mark.parametrize("angle", (-5, 0, 5))
57
+ def test_freqgrid_comparison(self, shape, sampling_rate, angle):
58
+ grid = frequency_grid_at_angle(
59
+ shape=shape,
60
+ angle=angle,
61
+ sampling_rate=sampling_rate,
62
+ opening_axis=2,
63
+ tilt_axis=0,
64
+ )
65
+ grid2 = fftfreqn(
66
+ shape=shape[1:], sampling_rate=sampling_rate, compute_euclidean_norm=True
67
+ )
68
+
69
+ # These should be equal for cubical input shapes
70
+ assert np.allclose(grid, grid2)
71
+
54
72
  @pytest.mark.parametrize("n", [10, 100, 1000])
55
73
  @pytest.mark.parametrize("sampling_rate", range(1, 4))
56
74
  def test_fftfreqn(self, n, sampling_rate):
tests/test_analyzer.py CHANGED
@@ -165,7 +165,6 @@ class TestMaxScoreOverRotations:
165
165
  assert res[0].dtype == be._float_dtype
166
166
  assert res[1].size == self.data.ndim
167
167
  assert np.allclose(res[2].shape, self.data.shape)
168
- assert len(res) == 4
169
168
 
170
169
  @pytest.mark.parametrize("use_memmap", [False, True])
171
170
  @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
@@ -181,7 +180,7 @@ class TestMaxScoreOverRotations:
181
180
 
182
181
  data2 = self.data * 2
183
182
  score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
184
- scores, translation_offset, rotations, mapping = score_analyzer.result(state)
183
+ scores, offset, rotations, mapping, *_ = score_analyzer.result(state)
185
184
 
186
185
  assert np.all(scores >= score_threshold)
187
186
  max_scores = np.maximum(self.data, data2)
@@ -214,7 +213,7 @@ class TestMaxScoreOverRotations:
214
213
  ret = MaxScoreOverRotations.merge(
215
214
  results=states, use_memmap=use_memmap, score_threshold=score_threshold
216
215
  )
217
- scores, translation, rotations, mapping = ret
216
+ scores, translation, rotations, mapping, *_ = ret
218
217
  assert np.all(scores >= score_threshold)
219
218
  max_scores = np.maximum(self.data, data2)
220
219
  max_scores = np.maximum(max_scores, score_threshold)
tests/test_backends.py CHANGED
@@ -292,16 +292,10 @@ class TestBackends:
292
292
 
293
293
  @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
294
294
  @pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
295
- def test_build_fft(self, backend, fast_shape):
295
+ def test_fft(self, backend, fast_shape):
296
296
  _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
297
297
  fast_shape, (1 for _ in range(len(fast_shape)))
298
298
  )
299
- rfftn, irfftn = backend.build_fft(
300
- fwd_shape=fast_shape,
301
- inv_shape=fast_ft_shape,
302
- real_dtype=backend._float_dtype,
303
- cmpl_dtype=backend._complex_dtype,
304
- )
305
299
  arr = np.random.rand(*fast_shape)
306
300
  out = np.zeros(fast_ft_shape)
307
301
 
@@ -310,11 +304,11 @@ class TestBackends:
310
304
  backend.to_backend_array(out), backend._complex_dtype
311
305
  )
312
306
 
313
- rfftn(
307
+ backend.rfftn(
314
308
  backend.astype(backend.to_backend_array(arr), backend._float_dtype),
315
309
  complex_arr,
316
310
  )
317
- irfftn(complex_arr, real_arr)
311
+ backend.irfftn(complex_arr, real_arr)
318
312
  assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)
319
313
 
320
314
  @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
tests/test_density.py CHANGED
@@ -98,7 +98,6 @@ class TestDensity:
98
98
  assert np.allclose(density.data, self.density.data)
99
99
  assert np.allclose(density.sampling_rate, self.density.sampling_rate)
100
100
  assert np.allclose(density.origin, self.density.origin)
101
- assert density.metadata == self.density.metadata
102
101
 
103
102
  def test_from_file_baseline(self):
104
103
  self.test_to_file(gzip=False)
tests/test_extensions.py CHANGED
@@ -53,7 +53,6 @@ class TestExtensions:
53
53
  @pytest.mark.parametrize("min_distance", [0, 5, 10])
54
54
  def test_find_candidate_indices(self, dimension, dtype, min_distance):
55
55
  coordinates = COORDINATES[dimension].astype(dtype)
56
- print(coordinates.shape)
57
56
 
58
57
  min_distance = np.array([min_distance]).astype(dtype)[0]
59
58
 
tests/test_matching_utils.py CHANGED
@@ -10,9 +10,6 @@ from tme.backends import backend as be
10
10
  from tme.memory import MATCHING_MEMORY_REGISTRY
11
11
  from tme.matching_utils import (
12
12
  compute_parallelization_schedule,
13
- elliptical_mask,
14
- box_mask,
15
- tube_mask,
16
13
  create_mask,
17
14
  scramble_phases,
18
15
  apply_convolution_mode,
@@ -50,73 +47,26 @@ class TestMatchingUtils:
50
47
  max_splits=256,
51
48
  )
52
49
 
53
- def test_create_mask(self):
50
+ @pytest.mark.parametrize("mask_type", ["ellipse", "box", "tube", "membrane"])
51
+ def test_create_mask(self, mask_type: str):
54
52
  create_mask(
55
- mask_type="ellipse",
53
+ mask_type=mask_type,
56
54
  shape=self.density.shape,
57
55
  radius=5,
58
56
  center=np.divide(self.density.shape, 2),
57
+ height=np.max(self.density.shape) // 2,
58
+ size=np.divide(self.density.shape, 2).astype(int),
59
+ thickness=2,
60
+ separation=2,
61
+ symmetry_axis=1,
62
+ inner_radius=5,
63
+ outer_radius=10,
59
64
  )
60
65
 
61
66
  def test_create_mask_error(self):
62
67
  with pytest.raises(ValueError):
63
68
  create_mask(mask_type=None)
64
69
 
65
- def test_elliptical_mask(self):
66
- elliptical_mask(
67
- shape=self.density.shape,
68
- radius=5,
69
- center=np.divide(self.density.shape, 2),
70
- )
71
-
72
- def test_box_mask(self):
73
- box_mask(
74
- shape=self.density.shape,
75
- height=[5, 10, 20],
76
- center=np.divide(self.density.shape, 2),
77
- )
78
-
79
- def test_tube_mask(self):
80
- tube_mask(
81
- shape=self.density.shape,
82
- outer_radius=10,
83
- inner_radius=5,
84
- height=5,
85
- base_center=np.divide(self.density.shape, 2),
86
- symmetry_axis=1,
87
- )
88
-
89
- def test_tube_mask_error(self):
90
- with pytest.raises(ValueError):
91
- tube_mask(
92
- shape=self.density.shape,
93
- outer_radius=5,
94
- inner_radius=10,
95
- height=5,
96
- base_center=np.divide(self.density.shape, 2),
97
- symmetry_axis=1,
98
- )
99
-
100
- with pytest.raises(ValueError):
101
- tube_mask(
102
- shape=self.density.shape,
103
- outer_radius=5,
104
- inner_radius=10,
105
- height=10 * np.max(self.density.shape),
106
- base_center=np.divide(self.density.shape, 2),
107
- symmetry_axis=1,
108
- )
109
-
110
- with pytest.raises(ValueError):
111
- tube_mask(
112
- shape=self.density.shape,
113
- outer_radius=5,
114
- inner_radius=10,
115
- height=10 * np.max(self.density.shape),
116
- base_center=np.divide(self.density.shape, 2),
117
- symmetry_axis=len(self.density.shape) + 1,
118
- )
119
-
120
70
  def test_scramble_phases(self):
121
71
  scramble_phases(arr=self.density.data, noise_proportion=0.5)
122
72
 
tests/test_rotations.py CHANGED
@@ -8,6 +8,7 @@ from tme import Density
8
8
  from scipy.spatial.transform import Rotation
9
9
  from scipy.signal import correlate
10
10
 
11
+ from tme.mask import elliptical_mask
11
12
  from tme.rotations import (
12
13
  euler_from_rotationmatrix,
13
14
  euler_to_rotationmatrix,
@@ -16,7 +17,6 @@ from tme.rotations import (
16
17
  get_rotation_matrices,
17
18
  )
18
19
  from tme.matching_utils import (
19
- elliptical_mask,
20
20
  split_shape,
21
21
  compute_full_convolution_index,
22
22
  )
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.3.b0.post1"
1
+ __version__ = "0.3.1"
tme/analyzer/_utils.py CHANGED
@@ -93,22 +93,22 @@ def cart_to_score(
93
93
  templateshape = be.to_backend_array(templateshape)
94
94
  convolution_shape = be.to_backend_array(convolution_shape)
95
95
 
96
- # Compute removed padding
97
96
  output_shape = _convmode_to_shape(
98
97
  convolution_mode=convolution_mode,
99
98
  targetshape=targetshape,
100
99
  templateshape=templateshape,
101
100
  convolution_shape=convolution_shape,
102
101
  )
103
- valid_positions = be.multiply(positions >= 0, positions < output_shape)
104
- valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
105
102
 
103
+ # Offset from padding the target
106
104
  starts = be.astype(
107
105
  be.divide(be.subtract(convolution_shape, output_shape), 2),
108
106
  be._int_dtype,
109
107
  )
110
-
111
108
  positions = be.add(positions, starts)
109
+
110
+ valid_positions = be.multiply(positions >= 0, positions < fast_shape)
111
+ valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
112
112
  if fourier_shift is not None:
113
113
  fourier_shift = be.to_backend_array(fourier_shift)
114
114
  positions = be.subtract(positions, fourier_shift)
tme/analyzer/aggregation.py CHANGED
@@ -132,12 +132,14 @@ class MaxScoreOverRotations(AbstractAnalyzer):
132
132
  - scores : BackendArray of shape `self._shape` filled with `score_threshold`.
133
133
  - rotations : BackendArray of shape `self._shape` filled with -1.
134
134
  - rotation_mapping : dict, empty mapping from rotation bytes to indices.
135
+ - ssum : BackendArray, accumulator for sum of squared scores.
135
136
  """
136
137
  scores = be.full(
137
138
  shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
138
139
  )
139
140
  rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
140
- return scores, rotations, {}
141
+ ssum = be.full((1), dtype=be._float_dtype, fill_value=0)
142
+ return scores, rotations, {}, ssum
141
143
 
142
144
  def __call__(
143
145
  self,
@@ -156,6 +158,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
156
158
  - scores : BackendArray, current maximum scores.
157
159
  - rotations : BackendArray, current rotation indices.
158
160
  - rotation_mapping : dict, mapping from rotation bytes to indices.
161
+ - ssum : BackendArray, accumulator for sum of squared scores.
159
162
  scores : BackendArray
160
163
  Array of new scores to update analyzer with.
161
164
  rotation_matrix : BackendArray
@@ -168,7 +171,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
168
171
  # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
169
172
  # If the analyzer is not shared and each rotation is unique, we can
170
173
  # use index to rotation mapping and invert prior to merging.
171
- prev_scores, rotations, rotation_mapping = state
174
+ prev_scores, rotations, rotation_mapping, ssum = state
172
175
 
173
176
  rotation_index = len(rotation_mapping)
174
177
  rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
@@ -180,13 +183,14 @@ class MaxScoreOverRotations(AbstractAnalyzer):
180
183
  rotation = be.tobytes(rotation_matrix)
181
184
  rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
182
185
 
186
+ ssum = be.add(ssum, be.ssum(scores), out=ssum)
183
187
  scores, rotations = be.max_score_over_rotations(
184
188
  scores=scores,
185
189
  max_scores=prev_scores,
186
190
  rotations=rotations,
187
191
  rotation_index=rotation_index,
188
192
  )
189
- return scores, rotations, rotation_mapping
193
+ return scores, rotations, rotation_mapping, ssum
190
194
 
191
195
  @staticmethod
192
196
  def _invert_rmap(rotation_mapping: dict) -> dict:
@@ -224,6 +228,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
224
228
  - scores : BackendArray, current maximum scores.
225
229
  - rotations : BackendArray, current rotation indices.
226
230
  - rotation_mapping : dict, mapping from rotation indices to matrices.
231
+ - ssum : BackendArray, accumulator for sum of squared scores.
227
232
  targetshape : Tuple[int], optional
228
233
  Shape of the target for convolution mode correction.
229
234
  templateshape : Tuple[int], optional
@@ -240,9 +245,9 @@ class MaxScoreOverRotations(AbstractAnalyzer):
240
245
  Returns
241
246
  -------
242
247
  tuple
243
- Final result tuple (scores, offset, rotations, rotation_mapping).
248
+ Final result tuple (scores, offset, rotations, rotation_mapping, ssum).
244
249
  """
245
- scores, rotations, rotation_mapping = state
250
+ scores, rotations, rotation_mapping, ssum = state
246
251
 
247
252
  # Apply postprocessing if parameters are provided
248
253
  if fourier_shift is not None:
@@ -269,11 +274,13 @@ class MaxScoreOverRotations(AbstractAnalyzer):
269
274
  if self._inversion_mapping:
270
275
  rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
271
276
 
277
+ n_rotations = max(len(rotation_mapping), 1)
272
278
  return (
273
279
  scores,
274
280
  be.to_numpy_array(self._offset),
275
281
  rotations,
276
282
  self._invert_rmap(rotation_mapping),
283
+ be.to_numpy_array(ssum) / (scores.size * n_rotations),
277
284
  )
278
285
 
279
286
  def _harmonize_states(states: List[Tuple]):
@@ -287,18 +294,18 @@ class MaxScoreOverRotations(AbstractAnalyzer):
287
294
  if states[i] is None:
288
295
  continue
289
296
 
290
- scores, offset, rotations, rotation_mapping = states[i]
297
+ scores, offset, rotations, rotation_mapping, ssum = states[i]
291
298
  if out_shape is None:
292
299
  out_shape = np.zeros(scores.ndim, int)
293
300
  out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
294
301
 
295
302
  new_param = {}
296
303
  for key, value in rotation_mapping.items():
297
- rotation_bytes = be.tobytes(value)
304
+ rotation_bytes = np.asarray(value).tobytes()
298
305
  new_param[rotation_bytes] = key
299
306
  if rotation_bytes not in new_rotation_mapping:
300
307
  new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
301
- states[i] = (scores, offset, rotations, new_param)
308
+ states[i] = (scores, offset, rotations, new_param, ssum)
302
309
  out_shape = tuple(int(x) for x in out_shape)
303
310
  return new_rotation_mapping, out_shape, states
304
311
 
@@ -329,11 +336,10 @@ class MaxScoreOverRotations(AbstractAnalyzer):
329
336
  if len(results) == 1:
330
337
  ret = results[0]
331
338
  if use_memmap:
332
- scores, offset, rotations, rotation_mapping = ret
339
+ scores, offset, rotations, rotation_mapping, ssum = ret
333
340
  scores = array_to_memmap(scores)
334
341
  rotations = array_to_memmap(rotations)
335
- ret = (scores, offset, rotations, rotation_mapping)
336
-
342
+ ret = (scores, offset, rotations, rotation_mapping, ssum)
337
343
  return ret
338
344
 
339
345
  # Determine output array shape and create consistent rotation map
@@ -368,6 +374,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
368
374
  )
369
375
  rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
370
376
 
377
+ total_ssum = 0
371
378
  for i in range(len(results)):
372
379
  if results[i] is None:
373
380
  continue
@@ -385,7 +392,9 @@ class MaxScoreOverRotations(AbstractAnalyzer):
385
392
  shape=out_shape,
386
393
  dtype=rotations_dtype,
387
394
  )
388
- scores, offset, rotations, rotation_mapping = results[i]
395
+ scores, offset, rotations, rotation_mapping, ssum = results[i]
396
+
397
+ total_ssum = np.add(total_ssum, ssum)
389
398
  stops = np.add(offset, scores.shape).astype(int)
390
399
  indices = tuple(slice(*pos) for pos in zip(offset, stops))
391
400
 
@@ -428,6 +437,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
428
437
  np.zeros(scores_out.ndim, dtype=int),
429
438
  rotations_out,
430
439
  cls._invert_rmap(master_rotation_mapping),
440
+ total_ssum / len(results),
431
441
  )
432
442
 
433
443
 
@@ -545,13 +555,19 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
545
555
  )
546
556
 
547
557
  def __call__(
548
- self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
558
+ self,
559
+ state: Tuple,
560
+ scores: BackendArray,
561
+ rotation_matrix: BackendArray,
562
+ **kwargs,
549
563
  ) -> Tuple:
550
564
  mask = self._get_constraint(rotation_matrix)
551
565
  mask = self._get_score_mask(mask=mask, scores=scores)
552
566
 
553
567
  scores = be.multiply(scores, mask, out=scores)
554
- return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
568
+ return super().__call__(
569
+ state, scores=scores, rotation_matrix=rotation_matrix, **kwargs
570
+ )
555
571
 
556
572
  def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
557
573
  """
@@ -636,7 +652,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
636
652
  return scores, rotations, {}
637
653
 
638
654
  def __call__(
639
- self, state, scores: BackendArray, rotation_matrix: BackendArray
655
+ self,
656
+ state,
657
+ scores: BackendArray,
658
+ rotation_matrix: BackendArray,
659
+ **kwargs,
640
660
  ) -> Tuple:
641
661
  prev_scores, rotations, rotation_mapping = state
642
662
 
tme/analyzer/peaks.py CHANGED
@@ -228,10 +228,15 @@ class PeakCaller(AbstractAnalyzer):
228
228
  translations = be.full(
229
229
  (self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
230
230
  )
231
+
232
+ rdim = len(self.shape)
233
+ if self.batch_dims:
234
+ rdim -= len(self.batch_dims)
235
+
231
236
  rotations = be.full(
232
- (self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
237
+ (self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
233
238
  )
234
- for i in range(ndim):
239
+ for i in range(rdim):
235
240
  rotations[:, i, i] = 1.0
236
241
 
237
242
  scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
@@ -750,8 +755,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
750
755
  Dictionary mapping values in rotations to Euler angles.
751
756
  By default None
752
757
  min_score : float
753
- Minimum score value to consider. If provided, superseeds limit given
754
- by :py:attr:`PeakCaller.num_peaks`.
758
+ Minimum score value to consider.
755
759
 
756
760
  Returns
757
761
  -------
@@ -774,10 +778,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
774
778
  mask = be.to_backend_array(mask)
775
779
  mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
776
780
 
777
- peak_limit = self.num_peaks
778
- if min_score is not None:
779
- peak_limit = be.size(scores)
780
- else:
781
+ if min_score is None:
781
782
  min_score = be.min(scores) - 1
782
783
 
783
784
  _scores = be.zeros(scores.shape, dtype=scores.dtype)
@@ -815,7 +816,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
815
816
  score_mask = mask_buffer[tmpl_slice] <= 0.1
816
817
 
817
818
  _scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
818
- if len(peaks) >= peak_limit:
819
+ if len(peaks) >= self.num_peaks:
819
820
  break
820
821
 
821
822
  return be.to_backend_array(peaks), None
@@ -851,7 +852,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
851
852
 
852
853
  rotation = rotation_mapping[rotation_space[tuple(peak)]]
853
854
 
854
- # TODO: Newer versions of rotation mapping contain rotation matrices not angles
855
+ # Old versions of rotation mapping contained Euler angles
855
856
  if rotation.ndim != 2:
856
857
  rotation = be.to_backend_array(
857
858
  euler_to_rotationmatrix(be.to_numpy_array(rotation))
tme/backends/_jax_utils.py CHANGED
@@ -112,10 +112,15 @@ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
112
112
  return arr
113
113
 
114
114
 
115
+ def _mask_scores(arr, mask):
116
+ return arr.at[:].multiply(mask)
117
+
118
+
115
119
  @partial(
116
120
  pmap,
117
- in_axes=(0,) + (None,) * 6,
118
- static_broadcasted_argnums=[6, 7],
121
+ in_axes=(0,) + (None,) * 7,
122
+ static_broadcasted_argnums=[7, 8, 9, 10],
123
+ axis_name="batch",
119
124
  )
120
125
  def scan(
121
126
  target: BackendArray,
@@ -124,11 +129,20 @@ def scan(
124
129
  rotations: BackendArray,
125
130
  template_filter: BackendArray,
126
131
  target_filter: BackendArray,
132
+ score_mask: BackendArray,
127
133
  fast_shape: Tuple[int],
128
134
  rotate_mask: bool,
135
+ analyzer_class: object,
136
+ analyzer_kwargs: Tuple[Tuple],
129
137
  ) -> Tuple[BackendArray, BackendArray]:
130
138
  eps = jnp.finfo(template.dtype).resolution
131
139
 
140
+ kwargs = lax.switch(
141
+ lax.axis_index("batch"),
142
+ [lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
143
+ )
144
+ analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
145
+
132
146
  if hasattr(target_filter, "shape"):
133
147
  target = _apply_fourier_filter(target, target_filter)
134
148
 
@@ -150,8 +164,12 @@ def scan(
150
164
  if template_filter.shape != ():
151
165
  _template_filter_func = _apply_fourier_filter
152
166
 
167
+ _score_mask_func = _identity
168
+ if score_mask.shape != ():
169
+ _score_mask_func = _mask_scores
170
+
153
171
  def _sample_transform(ret, rotation_matrix):
154
- max_scores, rotations, index = ret
172
+ state, index = ret
155
173
  template_rot, template_mask_rot = be.rigid_transform(
156
174
  arr=template,
157
175
  arr_mask=template_mask,
@@ -176,15 +194,10 @@ def scan(
176
194
  n_observations=n_observations,
177
195
  eps=eps,
178
196
  )
179
- max_scores, rotations = be.max_score_over_rotations(
180
- scores, max_scores, rotations, index
181
- )
182
- return (max_scores, rotations, index + 1), None
197
+ scores = _score_mask_func(scores, score_mask)
183
198
 
184
- score_space = jnp.zeros(fast_shape)
185
- rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
186
- (score_space, rotation_space, _), _ = lax.scan(
187
- _sample_transform, (score_space, rotation_space, 0), rotations
188
- )
199
+ state = analyzer(state, scores, rotation_matrix, rotation_index=index)
200
+ return (state, index + 1), None
189
201
 
190
- return score_space, rotation_space
202
+ (state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
203
+ return state