pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
  8. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
  65. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
@@ -192,6 +192,15 @@ class MaxScoreOverRotations(AbstractAnalyzer):
192
192
  )
193
193
  return scores, rotations, rotation_mapping, ssum
194
194
 
195
+ def correct_background(self, state, mean=0, inv_std=1, **kwargs):
196
+ scores, rotations, rotation_mapping, ssum = state
197
+
198
+ scores = be.subtract(scores, mean, out=scores)
199
+ scores = be.multiply(scores, inv_std, out=scores)
200
+
201
+ scores = be.maximum(scores, self._score_threshold, out=scores)
202
+ return scores, rotations, rotation_mapping, ssum
203
+
195
204
  @staticmethod
196
205
  def _invert_rmap(rotation_mapping: dict) -> dict:
197
206
  """
@@ -201,7 +210,12 @@ class MaxScoreOverRotations(AbstractAnalyzer):
201
210
  new_map, ndim = {}, None
202
211
  for k, v in rotation_mapping.items():
203
212
  nbytes = be.datatype_bytes(be._float_dtype)
204
- dtype = np.float32 if nbytes == 4 else np.float16
213
+ if nbytes == 8:
214
+ dtype = np.float64
215
+ elif nbytes == 4:
216
+ dtype = np.float32
217
+ else:
218
+ np.float16
205
219
  rmat = np.frombuffer(k, dtype=dtype)
206
220
  if ndim is None:
207
221
  ndim = int(np.sqrt(rmat.size))
@@ -451,7 +465,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
451
465
  Maximum accepted rotational deviation in degrees.
452
466
  positions : BackendArray
453
467
  Array of shape (n, d) with n seed point translations.
454
- positions : BackendArray
468
+ rotations : BackendArray
455
469
  Array of shape (n, d, d) with n seed point rotation matrices.
456
470
  reference : BackendArray
457
471
  Reference orientation of the template, wlog defaults to (0,0,1).
@@ -489,6 +503,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
489
503
  be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
490
504
  )
491
505
  positions = be.astype(be.to_backend_array(positions), be._int_dtype)
506
+ rotations = be.astype(be.to_backend_array(rotations), be._float_dtype)
492
507
 
493
508
  ndim = positions.shape[1]
494
509
  rotate_mask = len(set(acceptance_radius)) != 1
@@ -515,7 +530,13 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
515
530
  )
516
531
 
517
532
  self._positions = positions[valid_positions]
518
- rotations = be.to_backend_array(rotations)[valid_positions]
533
+ rotations = rotations[valid_positions]
534
+
535
+ # Convert to pull matrix to remain consistent with rotation convention
536
+ rotations = be.concatenate(
537
+ [rotations[i].T[None] for i in range(rotations.shape[0])]
538
+ )
539
+
519
540
  ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
520
541
  ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
521
542
  ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
@@ -524,6 +545,15 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
524
545
  self._normals_y = (rotations @ ey[..., None])[..., 0]
525
546
  self._normals_z = (rotations @ ez[..., None])[..., 0]
526
547
 
548
+ # All scores will be rejected in this case. We should think about a
549
+ # unified interface for checking analyzer validity to skip such runs
550
+ if self._positions.shape[0] == 0:
551
+
552
+ def _get_score_mask(*args, **kwargs):
553
+ return 0
554
+
555
+ self._get_score_mask = _get_score_mask
556
+
527
557
  # Periodic wrapping could be avoided by padding the target
528
558
  shape = be.to_backend_array(self._shape)
529
559
  starts = be.subtract(self._positions, extend)
@@ -539,9 +569,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
539
569
  self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
540
570
 
541
571
  if rotate_mask:
542
- self._score_mask = be.zeros(
543
- (rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
544
- )
572
+ self._score_mask = []
545
573
  for i in range(rotations.shape[0]):
546
574
  mask = create_mask(
547
575
  mask_type="ellipse",
@@ -550,9 +578,10 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
550
578
  center=tuple(extend for _ in range(ndim)),
551
579
  orientation=be.to_numpy_array(rotations[i]),
552
580
  )
553
- self._score_mask[i] = be.astype(
554
- be.to_backend_array(mask), be._float_dtype
581
+ self._score_mask.append(
582
+ be.astype(be.to_backend_array(mask), be._float_dtype)[None]
555
583
  )
584
+ self._score_mask = be.concatenate(self._score_mask)
556
585
 
557
586
  def __call__(
558
587
  self,
@@ -573,7 +602,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
573
602
  """
574
603
  Determine whether the angle between projection of reference w.r.t to
575
604
  a given rotation matrix and a set of rotations fall within the set
576
- cone_angle cutoff.
605
+ cone angle cutoff.
577
606
 
578
607
  Parameters
579
608
  ----------
@@ -585,7 +614,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
585
614
  BackerndArray
586
615
  Boolean mask of shape (n, )
587
616
  """
588
- template_rot = rotation_matrix @ self._reference
617
+ template_rot = rotation_matrix.T @ self._reference
589
618
 
590
619
  x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
591
620
  y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
@@ -596,10 +625,9 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
596
625
  def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
597
626
  score_mask = be.zeros(scores.shape, scores.dtype)
598
627
 
599
- if be.sum(mask) == 0:
600
- return score_mask
628
+ # The indexing could be improved to avoid expanding the mask to
629
+ # the number of seed points
601
630
  mask = be.reshape(mask, self._mask_shape)
602
-
603
631
  score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
604
632
  return score_mask > 0
605
633
 
@@ -663,13 +691,16 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
663
691
  rotation_index = len(rotation_mapping)
664
692
  if self._inversion_mapping:
665
693
  rotation_mapping[rotation_index] = rotation_matrix
694
+ elif self._jax_mode:
695
+ rotation_index = kwargs.get("rotation_index", 0)
666
696
  else:
667
697
  rotation = be.tobytes(rotation_matrix)
668
698
  rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
669
- max_score = be.max(scores, axis=self._aggregate_axis)
670
699
 
671
- update = prev_scores[rotation_index]
672
- update = be.maximum(max_score, update, out=update)
700
+ scores = be.max(scores, axis=self._aggregate_axis)
701
+ scores = be.maximum(scores, prev_scores[rotation_index])
702
+ prev_scores = be.at(prev_scores, rotation_index, scores)
703
+
673
704
  return prev_scores, rotations, rotation_mapping
674
705
 
675
706
  @classmethod
tme/analyzer/base.py CHANGED
@@ -73,6 +73,40 @@ class AbstractAnalyzer(ABC):
73
73
  Updated analyzer state incorporating the new data.
74
74
  """
75
75
 
76
+ @abstractmethod
77
+ def correct_background(self, state, mean=0, inv_std=1, **kwargs):
78
+ """
79
+ Applies flat-fielding correction to scores f as
80
+
81
+ .. math::
82
+
83
+ f' = (f - \\text{mean}) \\cdot \\text{inv_std},
84
+
85
+ transforming raw correlations to SNR-like scores.
86
+
87
+ Parameters
88
+ ----------
89
+ state : tuple
90
+ Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
91
+ or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
92
+ mean : BackendArray, optional
93
+ Background mean (or equivalent), defaults to 0.
94
+ inv_std : BackendArray, optional
95
+ Reciprocal background standard deviation (or equivalent), defaults to 1.
96
+
97
+ Notes
98
+ -----
99
+ This method should be called after all rotations have been processed
100
+ but before calling :py:meth:`result`. The correction helps distinguish genuine
101
+ template matches from systematic background artifacts that may arise from
102
+ template edges, interpolation artifacts, or structured noise in the target.
103
+
104
+ Returns
105
+ -------
106
+ tuple
107
+ Updated analyzer state incorporating the new data.
108
+ """
109
+
76
110
  @abstractmethod
77
111
  def result(self, state: Tuple, **kwargs) -> Tuple:
78
112
  """
tme/analyzer/peaks.py CHANGED
@@ -18,7 +18,6 @@ from .base import AbstractAnalyzer
18
18
  from ._utils import score_to_cart
19
19
  from ..backends import backend as be
20
20
  from ..types import BackendArray, NDArray
21
- from ..rotations import euler_to_rotationmatrix
22
21
  from ..matching_utils import split_shape, compute_extraction_box
23
22
 
24
23
  __all__ = [
@@ -182,6 +181,7 @@ class PeakCaller(AbstractAnalyzer):
182
181
  min_score: float = None,
183
182
  max_score: float = None,
184
183
  batch_dims: Tuple[int] = None,
184
+ projection_dims: Tuple[int] = None,
185
185
  shm_handler: object = None,
186
186
  **kwargs,
187
187
  ):
@@ -197,9 +197,13 @@ class PeakCaller(AbstractAnalyzer):
197
197
  self.min_distance = int(min_distance)
198
198
  self.min_boundary_distance = int(min_boundary_distance)
199
199
 
200
- self.batch_dims = batch_dims
200
+ self.batch_dims = ()
201
201
  if batch_dims is not None:
202
- self.batch_dims = tuple(int(x) for x in self.batch_dims)
202
+ self.batch_dims = tuple(int(x) for x in batch_dims)
203
+
204
+ self.projection_dims = ()
205
+ if projection_dims is not None:
206
+ self.projection_dims = tuple(int(x) for x in projection_dims)
203
207
 
204
208
  self.min_score, self.max_score = min_score, max_score
205
209
 
@@ -231,7 +235,7 @@ class PeakCaller(AbstractAnalyzer):
231
235
 
232
236
  rdim = len(self.shape)
233
237
  if self.batch_dims:
234
- rdim -= len(self.batch_dims)
238
+ rdim = rdim - len(self.batch_dims) + len(self.projection_dims)
235
239
 
236
240
  rotations = be.full(
237
241
  (self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
@@ -388,6 +392,20 @@ class PeakCaller(AbstractAnalyzer):
388
392
 
389
393
  return state
390
394
 
395
+ def correct_background(self, state, mean, inv_std=1, **kwargs):
396
+ arr_type = type(be.zeros((1,), be._float_dtype))
397
+ translations, rotations, scores, details = state
398
+
399
+ if isinstance(mean, arr_type):
400
+ mean = mean[tuple(be.astype(translations.T, int))]
401
+ scores = be.subtract(scores, mean, out=scores)
402
+
403
+ if isinstance(inv_std, arr_type):
404
+ inv_std = inv_std[tuple(be.astype(translations.T, int))]
405
+ scores = be.multiply(scores, inv_std, out=scores)
406
+
407
+ return translations, rotations, scores, details
408
+
391
409
  @classmethod
392
410
  def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
393
411
  """
@@ -778,6 +796,9 @@ class PeakCallerRecursiveMasking(PeakCaller):
778
796
  mask = be.to_backend_array(mask)
779
797
  mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
780
798
 
799
+ if min_score is None:
800
+ min_score = self.min_score
801
+
781
802
  if min_score is None:
782
803
  min_score = be.min(scores) - 1
783
804
 
@@ -849,15 +870,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
849
870
  """
850
871
  if rotation_space is None or rotation_mapping is None:
851
872
  return rotation_matrix
852
-
853
- rotation = rotation_mapping[rotation_space[tuple(peak)]]
854
-
855
- # Old versions of rotation mapping contained Euler angles
856
- if rotation.ndim != 2:
857
- rotation = be.to_backend_array(
858
- euler_to_rotationmatrix(be.to_numpy_array(rotation))
859
- )
860
- return rotation
873
+ return rotation_mapping[rotation_space[tuple(peak)]]
861
874
 
862
875
 
863
876
  class PeakCallerScipy(PeakCaller):
tme/analyzer/proxy.py CHANGED
@@ -85,6 +85,16 @@ class StatelessSharedAnalyzerProxy:
85
85
  final_state = tuple(self._shared_to_object(x) for x in final_state)
86
86
  return self._analyzer.result(final_state, **kwargs)
87
87
 
88
+ def correct_background(self, state, *args, **kwargs):
89
+ if self._shared:
90
+ # Copy to not correct the internal score array across processes
91
+ backend_arr = type(be.zeros((1), dtype=be._float_dtype))
92
+ state = tuple(self._shared_to_object(x) for x in state)
93
+ state = tuple(
94
+ be.copy(x) if isinstance(x, backend_arr) else x for x in state
95
+ )
96
+ return self._analyzer.correct_background(state, *args, **kwargs)
97
+
88
98
  def merge(self, *args, **kwargs):
89
99
  return self._analyzer.merge(*args, **kwargs)
90
100
 
@@ -121,3 +131,7 @@ class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
121
131
  def result(self, **kwargs):
122
132
  """Extract final result"""
123
133
  return super().result(self._state, **kwargs)
134
+
135
+ def correct_background(self, *args, **kwargs):
136
+ # We always assign to state as this operation can not be shared
137
+ self._state = super().correct_background(self._state, *args, **kwargs)
@@ -10,14 +10,14 @@ from typing import Tuple
10
10
  from functools import partial
11
11
 
12
12
  import jax.numpy as jnp
13
- from jax import pmap, lax, vmap
13
+ from jax import pmap, lax, jit
14
14
 
15
15
  from ..types import BackendArray
16
16
  from ..backends import backend as be
17
- from ..matching_utils import normalize_template as _normalize_template
17
+ from ..matching_utils import standardize, to_padded
18
18
 
19
19
 
20
- __all__ = ["scan"]
20
+ __all__ = ["scan", "setup_scan"]
21
21
 
22
22
 
23
23
  def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
@@ -62,15 +62,14 @@ def _flcSphere_scoring(
62
62
  Computes :py:meth:`tme.matching_scores.corr_scoring`.
63
63
  """
64
64
  correlation = _correlate(template=template, ft_target=ft_target)
65
- correlation = correlation.at[:].multiply(inv_denominator)
66
- return correlation
65
+ return correlation.at[:].multiply(inv_denominator)
67
66
 
68
67
 
69
68
  def _reciprocal_target_std(
70
69
  ft_target: BackendArray,
71
70
  ft_target2: BackendArray,
72
71
  template_mask: BackendArray,
73
- n_observations: float,
72
+ n_obs: float,
74
73
  eps: float,
75
74
  ) -> BackendArray:
76
75
  """
@@ -80,16 +79,16 @@ def _reciprocal_target_std(
80
79
  --------
81
80
  :py:meth:`tme.matching_scores.flc_scoring`.
82
81
  """
83
- ft_shape = template_mask.shape
84
- ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
82
+ shape = template_mask.shape
83
+ ft_template_mask = jnp.fft.rfftn(template_mask, s=shape)
85
84
 
86
85
  # E(X^2)- E(X)^2
87
- exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=ft_shape)
88
- exp_sq = exp_sq.at[:].divide(n_observations)
86
+ exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=shape)
87
+ exp_sq = exp_sq.at[:].divide(n_obs)
89
88
 
90
89
  ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
91
- sq_exp = jnp.fft.irfftn(ft_template_mask, s=ft_shape)
92
- sq_exp = sq_exp.at[:].divide(n_observations)
90
+ sq_exp = jnp.fft.irfftn(ft_template_mask, s=shape)
91
+ sq_exp = sq_exp.at[:].divide(n_obs)
93
92
  sq_exp = sq_exp.at[:].power(2)
94
93
 
95
94
  exp_sq = exp_sq.at[:].add(-sq_exp)
@@ -97,7 +96,7 @@ def _reciprocal_target_std(
97
96
  exp_sq = exp_sq.at[:].power(0.5)
98
97
 
99
98
  exp_sq = exp_sq.at[:].set(
100
- jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_observations))
99
+ jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_obs))
101
100
  )
102
101
  return exp_sq
103
102
 
@@ -108,20 +107,50 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
108
107
  return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
109
108
 
110
109
 
111
- def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
112
- return arr
110
+ def setup_scan(analyzer_kwargs, analyzer, fast_shape, rotate_mask, match_projection):
111
+ """Create separate scan function with initialized analyzer for each device"""
112
+ device_scans = [
113
+ partial(
114
+ scan,
115
+ fast_shape=fast_shape,
116
+ rotate_mask=rotate_mask,
117
+ analyzer=analyzer(**device_config),
118
+ )
119
+ for device_config in analyzer_kwargs
120
+ ]
113
121
 
122
+ @partial(
123
+ pmap,
124
+ in_axes=(0,) + (None,) * 7,
125
+ axis_name="batch",
126
+ )
127
+ def scan_combined(
128
+ target,
129
+ template,
130
+ template_mask,
131
+ rotations,
132
+ template_filter,
133
+ target_filter,
134
+ score_mask,
135
+ background_template,
136
+ ):
137
+ return lax.switch(
138
+ lax.axis_index("batch"),
139
+ device_scans,
140
+ target,
141
+ template,
142
+ template_mask,
143
+ rotations,
144
+ template_filter,
145
+ target_filter,
146
+ score_mask,
147
+ background_template,
148
+ )
114
149
 
115
- def _mask_scores(arr, mask):
116
- return arr.at[:].multiply(mask)
150
+ return scan_combined
117
151
 
118
152
 
119
- @partial(
120
- pmap,
121
- in_axes=(0,) + (None,) * 7,
122
- static_broadcasted_argnums=[7, 8, 9, 10],
123
- axis_name="batch",
124
- )
153
+ @partial(jit, static_argnums=(8, 9, 10))
125
154
  def scan(
126
155
  target: BackendArray,
127
156
  template: BackendArray,
@@ -130,74 +159,98 @@ def scan(
130
159
  template_filter: BackendArray,
131
160
  target_filter: BackendArray,
132
161
  score_mask: BackendArray,
162
+ background_template: BackendArray,
133
163
  fast_shape: Tuple[int],
134
164
  rotate_mask: bool,
135
- analyzer_class: object,
136
- analyzer_kwargs: Tuple[Tuple],
137
- ) -> Tuple[BackendArray, BackendArray]:
165
+ analyzer: object,
166
+ ) -> Tuple:
138
167
  eps = jnp.finfo(template.dtype).resolution
139
168
 
140
- kwargs = lax.switch(
141
- lax.axis_index("batch"),
142
- [lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
143
- )
144
- analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
145
-
146
- if hasattr(target_filter, "shape"):
169
+ if target_filter.shape != ():
147
170
  target = _apply_fourier_filter(target, target_filter)
148
171
 
149
172
  ft_target = jnp.fft.rfftn(target, s=fast_shape)
150
173
  ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
151
- inv_denominator, target, scoring_func = None, None, _flc_scoring
174
+ _n_obs, _inv_denominator, target = None, None, None
175
+
176
+ unpadded_slice = tuple(slice(0, x) for x in template.shape)
177
+ rot_buffer, mask_rot_buffer = jnp.zeros(fast_shape), jnp.zeros(fast_shape)
152
178
  if not rotate_mask:
153
- n_observations = jnp.sum(template_mask)
154
- inv_denominator = _reciprocal_target_std(
179
+ _n_obs = jnp.sum(template_mask)
180
+ _inv_denominator = _reciprocal_target_std(
155
181
  ft_target=ft_target,
156
182
  ft_target2=ft_target2,
157
- template_mask=be.topleft_pad(template_mask, fast_shape),
183
+ template_mask=to_padded(mask_rot_buffer, template_mask, unpadded_slice),
158
184
  eps=eps,
159
- n_observations=n_observations,
185
+ n_obs=_n_obs,
160
186
  )
161
- ft_target2, scoring_func = None, _flcSphere_scoring
187
+ ft_target2 = None
162
188
 
163
- _template_filter_func = _identity
164
- if template_filter.shape != ():
165
- _template_filter_func = _apply_fourier_filter
189
+ mask_scores = score_mask.shape != ()
190
+ filter_template = template_filter.shape != ()
191
+ bg_correction = background_template.shape != ()
192
+ bg_scores = jnp.zeros(fast_shape) if bg_correction else 0
166
193
 
167
- _score_mask_func = _identity
168
- if score_mask.shape != ():
169
- _score_mask_func = _mask_scores
194
+ _template_mask_rot = template_mask
195
+ template_indices = be._index_grid(template.shape)
196
+ center = be.divide(be.to_backend_array(template.shape) - 1, 2)
170
197
 
171
198
  def _sample_transform(ret, rotation_matrix):
172
- state, index = ret
173
- template_rot, template_mask_rot = be.rigid_transform(
174
- arr=template,
175
- arr_mask=template_mask,
176
- rotation_matrix=rotation_matrix,
177
- order=1, # thats all we get for now
199
+ matrix = be._build_transform_matrix(
200
+ rotation_matrix=rotation_matrix, center=center
178
201
  )
202
+ indices = be._transform_indices(template_indices, matrix)
203
+
204
+ template_rot = be._interpolate(template, indices, order=1)
205
+ n_obs, template_mask_rot = _n_obs, _template_mask_rot
206
+ if rotate_mask:
207
+ template_mask_rot = be._interpolate(template_mask, indices, order=1)
208
+ n_obs = jnp.sum(template_mask_rot)
209
+
210
+ if filter_template:
211
+ template_rot = _apply_fourier_filter(template_rot, template_filter)
212
+ template_rot = standardize(template_rot, template_mask_rot, n_obs)
213
+
214
+ rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
215
+
216
+ inv_denominator = _inv_denominator
217
+ if rotate_mask:
218
+ mask_rot_pad = to_padded(mask_rot_buffer, template_mask_rot, unpadded_slice)
219
+ inv_denominator = _reciprocal_target_std(
220
+ ft_target=ft_target,
221
+ ft_target2=ft_target2,
222
+ template_mask=mask_rot_pad,
223
+ n_obs=n_obs,
224
+ eps=eps,
225
+ )
226
+
227
+ scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
228
+ if mask_scores:
229
+ scores = scores.at[:].multiply(score_mask)
230
+
231
+ state, bg_scores, index = ret
232
+ state = analyzer(state, scores, rotation_matrix, rotation_index=index)
179
233
 
180
- n_observations = jnp.sum(template_mask_rot)
181
- template_rot = _template_filter_func(template_rot, template_filter)
182
- template_rot = _normalize_template(
183
- template_rot, template_mask_rot, n_observations
184
- )
185
- rot_pad = be.topleft_pad(template_rot, fast_shape)
186
- mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
234
+ if bg_correction:
235
+ template_rot = be._interpolate(background_template, indices, order=1)
236
+ if filter_template:
237
+ template_rot = _apply_fourier_filter(template_rot, template_filter)
238
+ template_rot = standardize(template_rot, template_mask_rot, n_obs)
187
239
 
188
- scores = scoring_func(
189
- template=rot_pad,
190
- template_mask=mask_rot_pad,
191
- ft_target=ft_target,
192
- ft_target2=ft_target2,
193
- inv_denominator=inv_denominator,
194
- n_observations=n_observations,
195
- eps=eps,
196
- )
197
- scores = _score_mask_func(scores, score_mask)
240
+ rot_pad = to_padded(rot_buffer, template_rot, unpadded_slice)
241
+ scores = _flcSphere_scoring(ft_target, rot_pad, inv_denominator)
242
+ bg_scores = jnp.maximum(bg_scores, scores)
198
243
 
199
- state = analyzer(state, scores, rotation_matrix, rotation_index=index)
200
- return (state, index + 1), None
244
+ return (state, bg_scores, index + 1), None
245
+
246
+ (state, bg_scores, _), _ = lax.scan(
247
+ _sample_transform, (analyzer.init_state(), bg_scores, 0), rotations
248
+ )
249
+
250
+ if bg_correction:
251
+ if mask_scores:
252
+ bg_scores = bg_scores.at[:].multiply(score_mask)
253
+ bg_scores = bg_scores.at[:].add(-be.mean(bg_scores))
254
+ state = analyzer.correct_background(state, bg_scores)
201
255
 
202
- (state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
203
256
  return state
@@ -33,7 +33,6 @@ class CupyBackend(NumpyFFTWBackend):
33
33
  import cupy as cp
34
34
  import cupyx.scipy.fft as cufft
35
35
  from cupyx.scipy.ndimage import affine_transform, maximum_filter
36
- from ._cupy_utils import affine_transform_batch
37
36
 
38
37
  float_dtype = cp.float32 if float_dtype is None else float_dtype
39
38
  complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
@@ -51,7 +50,6 @@ class CupyBackend(NumpyFFTWBackend):
51
50
  self._cufft = cufft
52
51
  self.maximum_filter = maximum_filter
53
52
  self.affine_transform = affine_transform
54
- self.affine_transform_batch = affine_transform_batch
55
53
 
56
54
  itype = f"int{self.datatype_bytes(int_dtype) * 8}"
57
55
  ftype = f"float{self.datatype_bytes(float_dtype) * 8}"
@@ -157,8 +155,8 @@ class CupyBackend(NumpyFFTWBackend):
157
155
 
158
156
  from voltools import StaticVolume
159
157
 
160
- # Only keep template and potential corresponding mask in cache
161
- if len(TEXTURE_CACHE) >= 2:
158
+ # Only keep template, mask and noise template in cache
159
+ if len(TEXTURE_CACHE) >= 3:
162
160
  TEXTURE_CACHE.clear()
163
161
 
164
162
  interpolation = "filt_bspline"
@@ -174,7 +172,7 @@ class CupyBackend(NumpyFFTWBackend):
174
172
 
175
173
  return TEXTURE_CACHE[key]
176
174
 
177
- def _rigid_transform(
175
+ def _transform(
178
176
  self,
179
177
  data: CupyArray,
180
178
  matrix: CupyArray,
@@ -182,21 +180,10 @@ class CupyBackend(NumpyFFTWBackend):
182
180
  prefilter: bool,
183
181
  order: int,
184
182
  cache: bool = False,
185
- batched: bool = False,
186
- ) -> None:
183
+ ) -> CupyArray:
187
184
  out_slice = tuple(slice(0, stop) for stop in data.shape)
188
- if batched:
189
- self.affine_transform_batch(
190
- input=data,
191
- matrix=matrix,
192
- mode="constant",
193
- output=output[out_slice],
194
- order=order,
195
- prefilter=prefilter,
196
- )
197
- return None
198
185
 
199
- if data.ndim == 3 and cache and self.texture_available and not batched:
186
+ if data.ndim == 3 and cache and self.texture_available:
200
187
  # Device memory pool (should) come to rescue performance
201
188
  temp = self.zeros(data.shape, data.dtype)
202
189
  texture = self._get_texture(data, order=order, prefilter=prefilter)
@@ -204,7 +191,7 @@ class CupyBackend(NumpyFFTWBackend):
204
191
  output[out_slice] = temp
205
192
  return None
206
193
 
207
- self.affine_transform(
194
+ return self.affine_transform(
208
195
  input=data,
209
196
  matrix=matrix,
210
197
  mode="constant",