pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. pytme-0.3b0.post1.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3b0.post1.data/scripts/match_template.py +1098 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +318 -189
  4. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +12 -12
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +21 -20
  8. pytme-0.3b0.post1.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +341 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +318 -189
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +12 -12
  19. scripts/pytme_runner.py +769 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/filters/__init__.py +3 -3
  49. tme/filters/_utils.py +36 -10
  50. tme/filters/bandpass.py +229 -188
  51. tme/filters/compose.py +5 -4
  52. tme/filters/ctf.py +516 -254
  53. tme/filters/reconstruction.py +91 -32
  54. tme/filters/wedge.py +196 -135
  55. tme/filters/whitening.py +37 -42
  56. tme/matching_data.py +28 -39
  57. tme/matching_exhaustive.py +31 -27
  58. tme/matching_optimization.py +5 -4
  59. tme/matching_scores.py +25 -15
  60. tme/matching_utils.py +54 -9
  61. tme/memory.py +4 -3
  62. tme/orientations.py +22 -9
  63. tme/parser.py +114 -33
  64. tme/preprocessor.py +6 -5
  65. tme/rotations.py +10 -7
  66. tme/structure.py +4 -3
  67. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  68. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  69. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  70. scripts/estimate_ram_usage.py +0 -97
  71. tests/data/Maps/.DS_Store +0 -0
  72. tests/data/Structures/.DS_Store +0 -0
  73. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
@@ -1,34 +1,34 @@
1
- """ Implements classes to analyze outputs from exhaustive template matching.
1
+ """
2
+ Implements classes to analyze outputs from exhaustive template matching.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
- from contextlib import nullcontext
9
- from multiprocessing import Manager
10
- from typing import Tuple, List, Dict, Generator
9
+ from typing import Tuple, List, Dict
11
10
 
12
11
  import numpy as np
13
12
 
13
+ from .base import AbstractAnalyzer
14
14
  from ..types import BackendArray
15
15
  from ._utils import cart_to_score
16
16
  from ..backends import backend as be
17
17
  from ..matching_utils import (
18
18
  create_mask,
19
19
  array_to_memmap,
20
- generate_tempfile_name,
21
20
  apply_convolution_mode,
21
+ generate_tempfile_name,
22
22
  )
23
23
 
24
-
25
24
  __all__ = [
26
25
  "MaxScoreOverRotations",
26
+ "MaxScoreOverRotationsConstrained",
27
27
  "MaxScoreOverTranslations",
28
28
  ]
29
29
 
30
30
 
31
- class MaxScoreOverRotations:
31
+ class MaxScoreOverRotations(AbstractAnalyzer):
32
32
  """
33
33
  Determine the rotation maximizing the score over all possible translations.
34
34
 
@@ -36,10 +36,6 @@ class MaxScoreOverRotations:
36
36
  ----------
37
37
  shape : tuple of int
38
38
  Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
39
- scores : BackendArray, optional
40
- Array mapping translations to scores.
41
- rotations : BackendArray, optional
42
- Array mapping translations to rotation indices.
43
39
  offset : BackendArray, optional
44
40
  Coordinate origin considered during merging, zero by default.
45
41
  score_threshold : float, optional
@@ -50,58 +46,47 @@ class MaxScoreOverRotations:
50
46
  Memmap internal arrays, False by default.
51
47
  thread_safe: bool, optional
52
48
  Allow class to be modified by multiple processes, True by default.
53
- only_unique_rotations : bool, optional
54
- Whether each rotation will be shown only once, False by default.
55
-
56
- Attributes
57
- ----------
58
- scores : BackendArray
59
- Mapping of translations to scores.
60
- rotations : BackendArray
61
- Mmapping of translations to rotation indices.
62
- rotation_mapping : Dict
63
- Mapping of rotations to rotation indices.
64
- offset : BackendArray, optional
65
- Coordinate origin considered during merging, zero by default
49
+ inversion_mapping : bool, optional
50
+ Do not use rotation matrix bytestrings for intermediate data handling.
51
+ This is useful for GPU backend where analyzers are not shared across
52
+ devices and every rotation is only observed once. It is generally
53
+ safe to deactivate inversion mapping, but at a cost of performance.
66
54
 
67
55
  Examples
68
56
  --------
69
57
  The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
70
58
  instance
71
59
 
60
+ >>> import numpy as np
72
61
  >>> from tme.analyzer import MaxScoreOverRotations
73
- >>> analyzer = MaxScoreOverRotations(shape = (50, 50))
62
+ >>> analyzer = MaxScoreOverRotations(shape=(50, 50))
74
63
 
75
64
  The following simulates a template matching run by creating random data for a range
76
65
  of rotations and sending it to ``analyzer`` via its __call__ method
77
66
 
67
+ >>> state = analyzer.init_state()
78
68
  >>> for rotation_number in range(10):
79
69
  >>> scores = np.random.rand(50,50)
80
70
  >>> rotation = np.random.rand(scores.ndim, scores.ndim)
81
- >>> analyzer(scores = scores, rotation_matrix = rotation)
71
+ >>> state = analyzer(state, scores=scores, rotation_matrix=rotation)
82
72
 
83
- The aggregated scores can be extracted by invoking the __iter__ method of
73
+ The aggregated scores can be extracted by invoking the result method of
84
74
  ``analyzer``
85
75
 
86
- >>> results = tuple(analyzer)
76
+ >>> results = analyzer.result(state)
87
77
 
88
78
  The ``results`` tuple contains (1) the maximum scores for each translation,
89
79
  (2) an offset which is relevant when merging results from split template matching
90
80
  using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
91
- score for a given translation, (4) a dictionary mapping rotation matrices to the
92
- indices used in (2).
81
+ score for a given translation, (4) a dictionary mapping indices used in (2) to
82
+ rotation matrices (2).
93
83
 
94
84
  We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
95
85
  as follows
96
86
 
97
87
  >>> optimal_score = results[0].max()
98
88
  >>> optimal_translation = np.where(results[0] == results[0].max())
99
- >>> optimal_rotation_index = results[2][optimal_translation]
100
- >>> for key, value in results[3].items():
101
- >>> if value != optimal_rotation_index:
102
- >>> continue
103
- >>> optimal_rotation = np.frombuffer(key, rotation.dtype)
104
- >>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
89
+ >>> optimal_rotation = results[2][optimal_translation]
105
90
 
106
91
  The outlined procedure is a trivial method to identify high scoring peaks.
107
92
  Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
@@ -111,156 +96,221 @@ class MaxScoreOverRotations:
111
96
  def __init__(
112
97
  self,
113
98
  shape: Tuple[int],
114
- scores: BackendArray = None,
115
- rotations: BackendArray = None,
116
99
  offset: BackendArray = None,
117
100
  score_threshold: float = 0,
118
101
  shm_handler: object = None,
119
102
  use_memmap: bool = False,
120
- thread_safe: bool = True,
121
- only_unique_rotations: bool = False,
103
+ inversion_mapping: bool = False,
104
+ jax_mode: bool = False,
122
105
  **kwargs,
123
106
  ):
107
+ self._use_memmap = use_memmap
108
+ self._score_threshold = score_threshold
124
109
  self._shape = tuple(int(x) for x in shape)
110
+ self._inversion_mapping = inversion_mapping
125
111
 
126
- self.scores = scores
127
- if self.scores is None:
128
- self.scores = be.full(
129
- shape=self._shape, dtype=be._float_dtype, fill_value=score_threshold
130
- )
131
- self.rotations = rotations
132
- if self.rotations is None:
133
- self.rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
134
-
135
- self.scores = be.to_sharedarr(self.scores, shm_handler)
136
- self.rotations = be.to_sharedarr(self.rotations, shm_handler)
112
+ self._jax_mode = jax_mode
113
+ if self._jax_mode:
114
+ self._inversion_mapping = False
137
115
 
138
116
  if offset is None:
139
117
  offset = be.zeros(len(self._shape), be._int_dtype)
140
- self.offset = be.astype(be.to_backend_array(offset), int)
118
+ self._offset = be.astype(be.to_backend_array(offset), int)
141
119
 
142
- self._use_memmap = use_memmap
143
- self._lock = Manager().Lock() if thread_safe else nullcontext()
144
- self._lock_is_nullcontext = isinstance(self.scores, type(be.zeros((1))))
145
- self._inversion_mapping = self._lock_is_nullcontext and only_unique_rotations
146
- self.rotation_mapping = Manager().dict() if thread_safe else {}
120
+ @property
121
+ def shareable(self):
122
+ return True
123
+
124
+ def init_state(self):
125
+ """
126
+ Initialize the analysis state.
127
+
128
+ Returns
129
+ -------
130
+ tuple
131
+ Initial state tuple containing (scores, rotations, rotation_mapping) where:
132
+ - scores : BackendArray of shape `self._shape` filled with `score_threshold`.
133
+ - rotations : BackendArray of shape `self._shape` filled with -1.
134
+ - rotation_mapping : dict, empty mapping from rotation bytes to indices.
135
+ """
136
+ scores = be.full(
137
+ shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
138
+ )
139
+ rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
140
+ return scores, rotations, {}
141
+
142
+ def __call__(
143
+ self,
144
+ state: Tuple,
145
+ scores: BackendArray,
146
+ rotation_matrix: BackendArray,
147
+ **kwargs,
148
+ ) -> Tuple:
149
+ """
150
+ Update the parameter store.
151
+
152
+ Parameters
153
+ ----------
154
+ state : tuple
155
+ Current state tuple (scores, rotations, rotation_mapping) where:
156
+ - scores : BackendArray, current maximum scores.
157
+ - rotations : BackendArray, current rotation indices.
158
+ - rotation_mapping : dict, mapping from rotation bytes to indices.
159
+ scores : BackendArray
160
+ Array of new scores to update analyzer with.
161
+ rotation_matrix : BackendArray
162
+ Square matrix used to obtain the current rotation.
163
+ Returns
164
+ -------
165
+ tuple
166
+ Updated state tuple (scores, rotations, rotation_mapping).
167
+ """
168
+ # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
169
+ # If the analyzer is not shared and each rotation is unique, we can
170
+ # use index to rotation mapping and invert prior to merging.
171
+ prev_scores, rotations, rotation_mapping = state
172
+
173
+ rotation_index = len(rotation_mapping)
174
+ rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
175
+ if self._inversion_mapping:
176
+ rotation_mapping[rotation_index] = rotation_matrix
177
+ elif self._jax_mode:
178
+ rotation_index = kwargs.get("rotation_index", 0)
179
+ else:
180
+ rotation = be.tobytes(rotation_matrix)
181
+ rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
182
+
183
+ scores, rotations = be.max_score_over_rotations(
184
+ scores=scores,
185
+ max_scores=prev_scores,
186
+ rotations=rotations,
187
+ rotation_index=rotation_index,
188
+ )
189
+ return scores, rotations, rotation_mapping
147
190
 
148
- def _postprocess(
191
+ @staticmethod
192
+ def _invert_rmap(rotation_mapping: dict) -> dict:
193
+ """
194
+ Invert dictionary from rotation matrix bytestrings mapping to rotation
195
+ indices ro rotation indices mapping to rotation matrices.
196
+ """
197
+ new_map, ndim = {}, None
198
+ for k, v in rotation_mapping.items():
199
+ nbytes = be.datatype_bytes(be._float_dtype)
200
+ dtype = np.float32 if nbytes == 4 else np.float16
201
+ rmat = np.frombuffer(k, dtype=dtype)
202
+ if ndim is None:
203
+ ndim = int(np.sqrt(rmat.size))
204
+ new_map[v] = rmat.reshape(ndim, ndim)
205
+ return new_map
206
+
207
+ def result(
149
208
  self,
150
- targetshape: Tuple[int],
151
- templateshape: Tuple[int],
152
- convolution_shape: Tuple[int],
209
+ state,
210
+ targetshape: Tuple[int] = None,
211
+ templateshape: Tuple[int] = None,
212
+ convolution_shape: Tuple[int] = None,
153
213
  fourier_shift: Tuple[int] = None,
154
214
  convolution_mode: str = None,
155
- shm_handler=None,
156
215
  **kwargs,
157
- ) -> "MaxScoreOverRotations":
158
- """Correct padding to Fourier shape and convolution mode."""
159
- scores = be.from_sharedarr(self.scores)
160
- rotations = be.from_sharedarr(self.rotations)
216
+ ) -> Tuple:
217
+ """
218
+ Finalize the analysis result with optional postprocessing.
219
+
220
+ Parameters
221
+ ----------
222
+ state : tuple
223
+ Current state tuple (scores, rotations, rotation_mapping) where:
224
+ - scores : BackendArray, current maximum scores.
225
+ - rotations : BackendArray, current rotation indices.
226
+ - rotation_mapping : dict, mapping from rotation indices to matrices.
227
+ targetshape : Tuple[int], optional
228
+ Shape of the target for convolution mode correction.
229
+ templateshape : Tuple[int], optional
230
+ Shape of the template for convolution mode correction.
231
+ convolution_shape : Tuple[int], optional
232
+ Shape used for convolution.
233
+ fourier_shift : Tuple[int], optional.
234
+ Shift to apply for Fourier correction.
235
+ convolution_mode : str, optional
236
+ Convolution mode for padding correction.
237
+ **kwargs
238
+ Additional keyword arguments.
239
+
240
+ Returns
241
+ -------
242
+ tuple
243
+ Final result tuple (scores, offset, rotations, rotation_mapping).
244
+ """
245
+ scores, rotations, rotation_mapping = state
246
+
247
+ # Apply postprocessing if parameters are provided
161
248
  if fourier_shift is not None:
162
249
  axis = tuple(i for i in range(len(fourier_shift)))
163
250
  scores = be.roll(scores, shift=fourier_shift, axis=axis)
164
251
  rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
165
252
 
166
- convargs = {
167
- "s1": targetshape,
168
- "s2": templateshape,
169
- "convolution_mode": convolution_mode,
170
- "convolution_shape": convolution_shape,
171
- }
172
253
  if convolution_mode is not None:
254
+ convargs = {
255
+ "s1": targetshape,
256
+ "s2": templateshape,
257
+ "convolution_mode": convolution_mode,
258
+ "convolution_shape": convolution_shape,
259
+ }
173
260
  scores = apply_convolution_mode(scores, **convargs)
174
261
  rotations = apply_convolution_mode(rotations, **convargs)
175
262
 
176
- self._shape, self.scores, self.rotations = scores.shape, scores, rotations
177
- if shm_handler is not None:
178
- self.scores = be.to_sharedarr(scores, shm_handler)
179
- self.rotations = be.to_sharedarr(rotations, shm_handler)
180
- return self
181
-
182
- def __iter__(self) -> Generator:
183
- scores = be.from_sharedarr(self.scores)
184
- rotations = be.from_sharedarr(self.rotations)
185
-
186
263
  scores = be.to_numpy_array(scores)
187
264
  rotations = be.to_numpy_array(rotations)
188
265
  if self._use_memmap:
189
266
  scores = array_to_memmap(scores)
190
267
  rotations = array_to_memmap(rotations)
191
- else:
192
- if type(self.scores) is not type(scores):
193
- # Copy to avoid invalidation by shared memory handler
194
- scores, rotations = scores.copy(), rotations.copy()
195
268
 
196
269
  if self._inversion_mapping:
197
- self.rotation_mapping = {
198
- be.tobytes(v): k for k, v in self.rotation_mapping.items()
199
- }
270
+ rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
200
271
 
201
- param_store = (
272
+ return (
202
273
  scores,
203
- be.to_numpy_array(self.offset),
274
+ be.to_numpy_array(self._offset),
204
275
  rotations,
205
- dict(self.rotation_mapping),
276
+ self._invert_rmap(rotation_mapping),
206
277
  )
207
- yield from param_store
208
278
 
209
- def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
279
+ def _harmonize_states(states: List[Tuple]):
210
280
  """
211
- Update the parameter store.
212
-
213
- Parameters
214
- ----------
215
- scores : BackendArray
216
- Array of scores.
217
- rotation_matrix : BackendArray
218
- Square matrix describing the current rotation.
281
+ Create consistent reference frame for merging different analyzer
282
+ instances, w.r.t. to rotations and output shape from different
283
+ splits of the target.
219
284
  """
220
- # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
221
- # If the analyzer is not shared and each rotation is unique, we can
222
- # use index to rotation mapping and invert prior to merging.
223
- if self._lock_is_nullcontext:
224
- rotation_index = len(self.rotation_mapping)
225
- if self._inversion_mapping:
226
- self.rotation_mapping[rotation_index] = rotation_matrix
227
- else:
228
- rotation = be.tobytes(rotation_matrix)
229
- rotation_index = self.rotation_mapping.setdefault(
230
- rotation, rotation_index
231
- )
232
- self.scores, self.rotations = be.max_score_over_rotations(
233
- scores=scores,
234
- max_scores=self.scores,
235
- rotations=self.rotations,
236
- rotation_index=rotation_index,
237
- )
238
- return None
285
+ new_rotation_mapping, out_shape = {}, None
286
+ for i in range(len(states)):
287
+ if states[i] is None:
288
+ continue
239
289
 
240
- rotation = be.tobytes(rotation_matrix)
241
- with self._lock:
242
- rotation_index = self.rotation_mapping.setdefault(
243
- rotation, len(self.rotation_mapping)
244
- )
245
- internal_scores = be.from_sharedarr(self.scores)
246
- internal_rotations = be.from_sharedarr(self.rotations)
247
- internal_sores, internal_rotations = be.max_score_over_rotations(
248
- scores=scores,
249
- max_scores=internal_scores,
250
- rotations=internal_rotations,
251
- rotation_index=rotation_index,
252
- )
253
- return None
290
+ scores, offset, rotations, rotation_mapping = states[i]
291
+ if out_shape is None:
292
+ out_shape = np.zeros(scores.ndim, int)
293
+ out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
294
+
295
+ new_param = {}
296
+ for key, value in rotation_mapping.items():
297
+ rotation_bytes = be.tobytes(value)
298
+ new_param[rotation_bytes] = key
299
+ if rotation_bytes not in new_rotation_mapping:
300
+ new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
301
+ states[i] = (scores, offset, rotations, new_param)
302
+ out_shape = tuple(int(x) for x in out_shape)
303
+ return new_rotation_mapping, out_shape, states
254
304
 
255
305
  @classmethod
256
- def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
306
+ def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
257
307
  """
258
308
  Merge multiple instances of the current class.
259
309
 
260
310
  Parameters
261
311
  ----------
262
- param_stores : list of tuple
263
- List of instance's internal state created by applying `tuple(instance)`.
312
+ results : list of tuple
313
+ List of instance's internal state created by applying `result`.
264
314
  **kwargs : dict, optional
265
315
  Optional keyword arguments.
266
316
 
@@ -276,8 +326,8 @@ class MaxScoreOverRotations:
276
326
  Mapping between rotations and rotation indices.
277
327
  """
278
328
  use_memmap = kwargs.get("use_memmap", False)
279
- if len(param_stores) == 1:
280
- ret = param_stores[0]
329
+ if len(results) == 1:
330
+ ret = results[0]
281
331
  if use_memmap:
282
332
  scores, offset, rotations, rotation_mapping = ret
283
333
  scores = array_to_memmap(scores)
@@ -287,25 +337,12 @@ class MaxScoreOverRotations:
287
337
  return ret
288
338
 
289
339
  # Determine output array shape and create consistent rotation map
290
- new_rotation_mapping, out_shape = {}, None
291
- for i in range(len(param_stores)):
292
- if param_stores[i] is None:
293
- continue
294
-
295
- scores, offset, rotations, rotation_mapping = param_stores[i]
296
- if out_shape is None:
297
- out_shape = np.zeros(scores.ndim, int)
298
- scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
299
- out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
300
-
301
- for key, value in rotation_mapping.items():
302
- if key not in new_rotation_mapping:
303
- new_rotation_mapping[key] = len(new_rotation_mapping)
304
-
340
+ master_rotation_mapping, out_shape, results = cls._harmonize_states(results)
305
341
  if out_shape is None:
306
342
  return None
307
343
 
308
- out_shape = tuple(int(x) for x in out_shape)
344
+ scores_dtype = results[0][0].dtype
345
+ rotations_dtype = results[0][2].dtype
309
346
  if use_memmap:
310
347
  scores_out_filename = generate_tempfile_name()
311
348
  rotations_out_filename = generate_tempfile_name()
@@ -331,8 +368,8 @@ class MaxScoreOverRotations:
331
368
  )
332
369
  rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
333
370
 
334
- for i in range(len(param_stores)):
335
- if param_stores[i] is None:
371
+ for i in range(len(results)):
372
+ if results[i] is None:
336
373
  continue
337
374
 
338
375
  if use_memmap:
@@ -348,7 +385,7 @@ class MaxScoreOverRotations:
348
385
  shape=out_shape,
349
386
  dtype=rotations_dtype,
350
387
  )
351
- scores, offset, rotations, rotation_mapping = param_stores[i]
388
+ scores, offset, rotations, rotation_mapping = results[i]
352
389
  stops = np.add(offset, scores.shape).astype(int)
353
390
  indices = tuple(slice(*pos) for pos in zip(offset, stops))
354
391
 
@@ -359,7 +396,7 @@ class MaxScoreOverRotations:
359
396
  len(rotation_mapping) + 1, dtype=rotations_out.dtype
360
397
  )
361
398
  for key, value in rotation_mapping.items():
362
- lookup_table[value] = new_rotation_mapping[key]
399
+ lookup_table[value] = master_rotation_mapping[key]
363
400
 
364
401
  updated_rotations = rotations[indices_update]
365
402
  if len(updated_rotations):
@@ -372,7 +409,7 @@ class MaxScoreOverRotations:
372
409
  rotations_out.flush()
373
410
  scores_out, rotations_out = None, None
374
411
 
375
- param_stores[i] = None
412
+ results[i] = None
376
413
  scores, rotations = None, None
377
414
 
378
415
  if use_memmap:
@@ -390,13 +427,165 @@ class MaxScoreOverRotations:
390
427
  scores_out,
391
428
  np.zeros(scores_out.ndim, dtype=int),
392
429
  rotations_out,
393
- new_rotation_mapping,
430
+ cls._invert_rmap(master_rotation_mapping),
394
431
  )
395
432
 
396
- @property
397
- def is_shareable(self) -> bool:
398
- """Boolean indicating whether class instance can be shared across processes."""
399
- return True
433
+
434
+ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
435
+ """
436
+ Implements constrained template matching using rejection sampling.
437
+
438
+ Parameters
439
+ ----------
440
+ cone_angle : float
441
+ Maximum accepted rotational deviation in degrees.
442
+ positions : BackendArray
443
+ Array of shape (n, d) with n seed point translations.
444
+ positions : BackendArray
445
+ Array of shape (n, d, d) with n seed point rotation matrices.
446
+ reference : BackendArray
447
+ Reference orientation of the template, wlog defaults to (0,0,1).
448
+ acceptance_radius : int or tuple of ints
449
+ Translational acceptance radius around seed point in voxels.
450
+ **kwargs : dict, optional
451
+ Keyword aguments passed to the constructor of :py:class:`MaxScoreOverRotations`.
452
+ """
453
+
454
+ def __init__(
455
+ self,
456
+ cone_angle: float,
457
+ positions: BackendArray,
458
+ rotations: BackendArray,
459
+ reference: BackendArray = (0, 0, 1),
460
+ acceptance_radius: int = 10,
461
+ **kwargs,
462
+ ):
463
+ MaxScoreOverRotations.__init__(self, **kwargs)
464
+
465
+ if not isinstance(acceptance_radius, (int, Tuple)):
466
+ raise ValueError("acceptance_radius needs to be of type int or tuple.")
467
+
468
+ if isinstance(acceptance_radius, int):
469
+ acceptance_radius = (
470
+ acceptance_radius,
471
+ acceptance_radius,
472
+ acceptance_radius,
473
+ )
474
+ acceptance_radius = tuple(int(x) for x in acceptance_radius)
475
+
476
+ self._cone_angle = float(np.radians(cone_angle))
477
+ self._cone_cutoff = float(np.tan(self._cone_angle))
478
+ self._reference = be.astype(
479
+ be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
480
+ )
481
+ positions = be.astype(be.to_backend_array(positions), be._int_dtype)
482
+
483
+ ndim = positions.shape[1]
484
+ rotate_mask = len(set(acceptance_radius)) != 1
485
+ extend = max(acceptance_radius)
486
+ mask = create_mask(
487
+ mask_type="ellipse",
488
+ radius=acceptance_radius,
489
+ shape=tuple(2 * extend + 1 for _ in range(ndim)),
490
+ center=tuple(extend for _ in range(ndim)),
491
+ )
492
+ self._score_mask = be.astype(be.to_backend_array(mask), be._float_dtype)
493
+
494
+ # Map position from real space to shifted score space
495
+ lower_limit = be.to_backend_array(self._offset)
496
+ positions = be.subtract(positions, lower_limit)
497
+ positions, valid_positions = cart_to_score(
498
+ positions=positions,
499
+ fast_shape=kwargs.get("fast_shape", None),
500
+ targetshape=kwargs.get("targetshape", None),
501
+ templateshape=kwargs.get("templateshape", None),
502
+ fourier_shift=kwargs.get("fourier_shift", None),
503
+ convolution_mode=kwargs.get("convolution_mode", None),
504
+ convolution_shape=kwargs.get("convolution_shape", None),
505
+ )
506
+
507
+ self._positions = positions[valid_positions]
508
+ rotations = be.to_backend_array(rotations)[valid_positions]
509
+ ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
510
+ ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
511
+ ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
512
+
513
+ self._normals_x = (rotations @ ex[..., None])[..., 0]
514
+ self._normals_y = (rotations @ ey[..., None])[..., 0]
515
+ self._normals_z = (rotations @ ez[..., None])[..., 0]
516
+
517
+ # Periodic wrapping could be avoided by padding the target
518
+ shape = be.to_backend_array(self._shape)
519
+ starts = be.subtract(self._positions, extend)
520
+ ret, (n, d), mshape = [], self._positions.shape, self._score_mask.shape
521
+ if starts.shape[0] > 0:
522
+ for i in range(d):
523
+ indices = starts[:, slice(i, i + 1)] + be.arange(mshape[i])[None]
524
+ indices = be.mod(indices, shape[i], out=indices)
525
+ indices_shape = (n, *tuple(1 if k != i else -1 for k in range(d)))
526
+ ret.append(be.reshape(indices, indices_shape))
527
+
528
+ self._index_grid = tuple(ret)
529
+ self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
530
+
531
+ if rotate_mask:
532
+ self._score_mask = be.zeros(
533
+ (rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
534
+ )
535
+ for i in range(rotations.shape[0]):
536
+ mask = create_mask(
537
+ mask_type="ellipse",
538
+ radius=acceptance_radius,
539
+ shape=tuple(2 * extend + 1 for _ in range(ndim)),
540
+ center=tuple(extend for _ in range(ndim)),
541
+ orientation=be.to_numpy_array(rotations[i]),
542
+ )
543
+ self._score_mask[i] = be.astype(
544
+ be.to_backend_array(mask), be._float_dtype
545
+ )
546
+
547
+ def __call__(
548
+ self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
549
+ ) -> Tuple:
550
+ mask = self._get_constraint(rotation_matrix)
551
+ mask = self._get_score_mask(mask=mask, scores=scores)
552
+
553
+ scores = be.multiply(scores, mask, out=scores)
554
+ return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
555
+
556
+ def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
557
+ """
558
+ Determine whether the angle between projection of reference w.r.t to
559
+ a given rotation matrix and a set of rotations fall within the set
560
+ cone_angle cutoff.
561
+
562
+ Parameters
563
+ ----------
564
+ rotation_matrix : BackendArray
565
+ Rotation matrix with shape (d,d).
566
+
567
+ Returns
568
+ -------
569
+ BackerndArray
570
+ Boolean mask of shape (n, )
571
+ """
572
+ template_rot = rotation_matrix @ self._reference
573
+
574
+ x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
575
+ y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
576
+ z = be.sum(be.multiply(self._normals_z, template_rot), axis=1)
577
+
578
+ return be.sqrt(x**2 + y**2) <= (z * self._cone_cutoff)
579
+
580
+ def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
581
+ score_mask = be.zeros(scores.shape, scores.dtype)
582
+
583
+ if be.sum(mask) == 0:
584
+ return score_mask
585
+ mask = be.reshape(mask, self._mask_shape)
586
+
587
+ score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
588
+ return score_mask > 0
400
589
 
401
590
 
402
591
  class MaxScoreOverTranslations(MaxScoreOverRotations):
@@ -437,43 +626,40 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
437
626
  super().__init__(
438
627
  shape=shape_reduced, shm_handler=shm_handler, offset=offset, **kwargs
439
628
  )
440
-
441
- self.rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
442
- self.rotations = be.to_sharedarr(self.rotations, shm_handler)
443
629
  self._aggregate_axis = aggregate_axis
444
630
 
445
- def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
446
- if self._lock_is_nullcontext:
447
- rotation_index = len(self.rotation_mapping)
448
- if self._inversion_mapping:
449
- self.rotation_mapping[rotation_index] = rotation_matrix
450
- else:
451
- rotation = be.tobytes(rotation_matrix)
452
- rotation_index = self.rotation_mapping.setdefault(
453
- rotation, rotation_index
454
- )
455
- max_score = be.max(scores, axis=self._aggregate_axis)
456
- self.scores[rotation_index] = max_score
457
- return None
631
+ def init_state(self) -> Tuple:
+ """
+ Create a fresh aggregation state.
+
+ Returns
+ -------
+ tuple
+ (scores, rotations, rotation_mapping). scores has shape self._shape
+ and is filled with self._score_threshold. rotations is a one-element
+ placeholder filled with -1; __call__ passes it through unchanged.
+ rotation_mapping is an empty dict.
+ """
632
+ scores = be.full(
633
+ shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
634
+ )
635
+ rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
636
+ return scores, rotations, {}
458
637
 
459
- rotation = be.tobytes(rotation_matrix)
460
- with self._lock:
461
- rotation_index = self.rotation_mapping.setdefault(
462
- rotation, len(self.rotation_mapping)
463
- )
464
- internal_scores = be.from_sharedarr(self.scores)
465
- max_score = be.max(scores, axis=self._aggregate_axis)
466
- internal_scores[rotation_index] = max_score
467
- return None
638
+ def __call__(
639
+ self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
640
+ ) -> Tuple:
+ """
+ Reduce scores over self._aggregate_axis and merge the result into the
+ slot for rotation_matrix using an element-wise maximum.
+
+ Parameters
+ ----------
+ state : tuple
+ (prev_scores, rotations, rotation_mapping) as created by init_state.
+ prev_scores is updated in place; rotation_mapping may gain an entry.
+ scores : BackendArray
+ Scores obtained for rotation_matrix.
+ rotation_matrix : BackendArray
+ Rotation matrix that produced scores.
+
+ Returns
+ -------
+ tuple
+ The same (prev_scores, rotations, rotation_mapping) objects.
+ """
641
+ prev_scores, rotations, rotation_mapping = state
642
+
643
+ rotation_index = len(rotation_mapping)
644
+ if self._inversion_mapping:
+ # index -> rotation: every call receives a new index.
645
+ rotation_mapping[rotation_index] = rotation_matrix
646
+ else:
+ # rotation bytes -> index: repeated rotations reuse their first index.
647
+ rotation = be.tobytes(rotation_matrix)
648
+ rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
649
+ max_score = be.max(scores, axis=self._aggregate_axis)
650
+
651
+ # NOTE(review): the in-place update relies on prev_scores[rotation_index]
+ # being a writable view (true for NumPy integer indexing when prev_scores is
+ # at least 2-D). For a 1-D prev_scores this gives a scalar and the result
+ # would be discarded - confirm shape invariants.
+ update = prev_scores[rotation_index]
652
+ update = be.maximum(max_score, update, out=update)
653
+ return prev_scores, rotations, rotation_mapping
468
654
 
469
655
  @classmethod
470
- def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
656
+ def merge(cls, states: List[Tuple], **kwargs) -> Tuple:
471
657
  """
472
658
  Merge multiple instances of the current class.
473
659
 
474
660
  Parameters
475
661
  ----------
476
- param_stores : list of tuple
662
+ states : list of tuple
477
663
  List of instance's internal state created by applying `tuple(instance)`.
478
664
  **kwargs : dict, optional
479
665
  Optional keyword arguments.
@@ -489,31 +675,18 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
489
675
  Dict
490
676
  Mapping between rotations and rotation indices.
491
677
  """
492
- if len(param_stores) == 1:
493
- return param_stores[0]
678
+ if len(states) == 1:
679
+ return states[0]
494
680
 
495
681
  # Determine output array shape and create consistent rotation map
496
- new_rotation_mapping, out_shape = {}, None
497
- for i in range(len(param_stores)):
498
- if param_stores[i] is None:
499
- continue
500
-
501
- scores, offset, rotations, rotation_mapping = param_stores[i]
502
- if out_shape is None:
503
- out_shape = np.zeros(scores.ndim, int)
504
- scores_dtype, rotations_out = scores.dtype, rotations
505
- out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
506
-
507
- for key, value in rotation_mapping.items():
508
- if key not in new_rotation_mapping:
509
- new_rotation_mapping[key] = len(new_rotation_mapping)
510
-
682
+ states, master_rotation_mapping, out_shape = cls._harmonize_states(states)
511
683
  if out_shape is None:
512
684
  return None
513
685
 
514
- out_shape[0] = len(new_rotation_mapping)
686
+ out_shape[0] = len(master_rotation_mapping)
515
687
  out_shape = tuple(int(x) for x in out_shape)
516
688
 
689
+ scores_dtype = states[0][0].dtype
517
690
  use_memmap = kwargs.get("use_memmap", False)
518
691
  if use_memmap:
519
692
  scores_out_filename = generate_tempfile_name()
@@ -529,8 +702,8 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
529
702
  dtype=scores_dtype,
530
703
  )
531
704
 
532
- for i in range(len(param_stores)):
533
- if param_stores[i] is None:
705
+ for i in range(len(states)):
706
+ if states[i] is None:
534
707
  continue
535
708
 
536
709
  if use_memmap:
@@ -540,11 +713,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
540
713
  shape=out_shape,
541
714
  dtype=scores_dtype,
542
715
  )
543
- scores, offset, rotations, rotation_mapping = param_stores[i]
716
+ scores, offset, rotations, rotation_mapping = states[i]
544
717
 
545
718
  outer_table = np.arange(len(rotation_mapping), dtype=int)
546
719
  lookup_table = np.array(
547
- [new_rotation_mapping[key] for key in rotation_mapping.keys()],
720
+ [master_rotation_mapping[key] for key in rotation_mapping.keys()],
548
721
  dtype=int,
549
722
  )
550
723
 
@@ -560,7 +733,7 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
560
733
  scores_out.flush()
561
734
  scores_out = None
562
735
 
563
- param_stores[i], scores = None, None
736
+ states[i], scores = None, None
564
737
 
565
738
  if use_memmap:
566
739
  scores_out = np.memmap(
@@ -570,9 +743,9 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
570
743
  return (
571
744
  scores_out,
572
745
  np.zeros(scores_out.ndim, dtype=int),
573
- rotations_out,
574
- new_rotation_mapping,
746
+ states[2],
747
+ cls._invert_rmap(master_rotation_mapping),
575
748
  )
576
749
 
577
- def _postprocess(self, **kwargs):
578
- return self
750
+ def result(self, state: Tuple, **kwargs) -> Tuple:
+ """
+ Return the final result for state. No post-processing is applied;
+ state is returned unchanged and kwargs are ignored.
+ """
751
+ return state