pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
  2. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
  3. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
  4. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
  5. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
  6. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
  7. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
  8. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
  9. scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
  10. scripts/extract_candidates.py +224 -0
  11. scripts/match_template.py +224 -223
  12. scripts/postprocess.py +283 -163
  13. scripts/preprocess.py +11 -8
  14. scripts/preprocessor_gui.py +10 -9
  15. scripts/refine_matches.py +626 -0
  16. tests/preprocessing/test_frequency_filters.py +9 -4
  17. tests/test_analyzer.py +143 -138
  18. tests/test_matching_cli.py +85 -29
  19. tests/test_matching_exhaustive.py +1 -2
  20. tests/test_matching_optimization.py +4 -9
  21. tests/test_orientations.py +0 -1
  22. tme/__version__.py +1 -1
  23. tme/analyzer/__init__.py +2 -0
  24. tme/analyzer/_utils.py +25 -17
  25. tme/analyzer/aggregation.py +385 -220
  26. tme/analyzer/base.py +138 -0
  27. tme/analyzer/peaks.py +150 -88
  28. tme/analyzer/proxy.py +122 -0
  29. tme/backends/__init__.py +4 -3
  30. tme/backends/_cupy_utils.py +25 -24
  31. tme/backends/_jax_utils.py +4 -3
  32. tme/backends/cupy_backend.py +4 -13
  33. tme/backends/jax_backend.py +6 -8
  34. tme/backends/matching_backend.py +4 -3
  35. tme/backends/mlx_backend.py +4 -3
  36. tme/backends/npfftw_backend.py +7 -5
  37. tme/backends/pytorch_backend.py +14 -4
  38. tme/cli.py +126 -0
  39. tme/density.py +4 -3
  40. tme/filters/__init__.py +1 -1
  41. tme/filters/_utils.py +4 -3
  42. tme/filters/bandpass.py +6 -4
  43. tme/filters/compose.py +5 -4
  44. tme/filters/ctf.py +426 -214
  45. tme/filters/reconstruction.py +58 -28
  46. tme/filters/wedge.py +139 -61
  47. tme/filters/whitening.py +36 -36
  48. tme/matching_data.py +4 -3
  49. tme/matching_exhaustive.py +17 -16
  50. tme/matching_optimization.py +5 -4
  51. tme/matching_scores.py +4 -3
  52. tme/matching_utils.py +6 -4
  53. tme/memory.py +4 -3
  54. tme/orientations.py +9 -6
  55. tme/parser.py +5 -4
  56. tme/preprocessor.py +4 -3
  57. tme/rotations.py +10 -7
  58. tme/structure.py +4 -3
  59. tests/data/Maps/.DS_Store +0 -0
  60. tests/data/Structures/.DS_Store +0 -0
  61. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
  62. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
  63. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
@@ -1,34 +1,34 @@
1
- """ Implements classes to analyze outputs from exhaustive template matching.
1
+ """
2
+ Implements classes to analyze outputs from exhaustive template matching.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
- from contextlib import nullcontext
9
- from multiprocessing import Manager
10
- from typing import Tuple, List, Dict, Generator
9
+ from typing import Tuple, List, Dict
11
10
 
12
11
  import numpy as np
13
12
 
13
+ from .base import AbstractAnalyzer
14
14
  from ..types import BackendArray
15
15
  from ._utils import cart_to_score
16
16
  from ..backends import backend as be
17
17
  from ..matching_utils import (
18
18
  create_mask,
19
19
  array_to_memmap,
20
- generate_tempfile_name,
21
20
  apply_convolution_mode,
21
+ generate_tempfile_name,
22
22
  )
23
23
 
24
-
25
24
  __all__ = [
26
25
  "MaxScoreOverRotations",
26
+ "MaxScoreOverRotationsConstrained",
27
27
  "MaxScoreOverTranslations",
28
28
  ]
29
29
 
30
30
 
31
- class MaxScoreOverRotations:
31
+ class MaxScoreOverRotations(AbstractAnalyzer):
32
32
  """
33
33
  Determine the rotation maximizing the score over all possible translations.
34
34
 
@@ -36,10 +36,6 @@ class MaxScoreOverRotations:
36
36
  ----------
37
37
  shape : tuple of int
38
38
  Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
39
- scores : BackendArray, optional
40
- Array mapping translations to scores.
41
- rotations : BackendArray, optional
42
- Array mapping translations to rotation indices.
43
39
  offset : BackendArray, optional
44
40
  Coordinate origin considered during merging, zero by default.
45
41
  score_threshold : float, optional
@@ -50,19 +46,11 @@ class MaxScoreOverRotations:
50
46
  Memmap internal arrays, False by default.
51
47
  thread_safe: bool, optional
52
48
  Allow class to be modified by multiple processes, True by default.
53
- only_unique_rotations : bool, optional
54
- Whether each rotation will be shown only once, False by default.
55
-
56
- Attributes
57
- ----------
58
- scores : BackendArray
59
- Mapping of translations to scores.
60
- rotations : BackendArray
61
- Mmapping of translations to rotation indices.
62
- rotation_mapping : Dict
63
- Mapping of rotations to rotation indices.
64
- offset : BackendArray, optional
65
- Coordinate origin considered during merging, zero by default
49
+ inversion_mapping : bool, optional
50
+ Do not use rotation matrix bytestrings for intermediate data handling.
51
+ This is useful for GPU backend where analyzers are not shared across
52
+ devices and every rotation is only observed once. It is generally
53
+ safe to deactivate inversion mapping, but at a cost of performance.
66
54
 
67
55
  Examples
68
56
  --------
@@ -75,33 +63,29 @@ class MaxScoreOverRotations:
75
63
  The following simulates a template matching run by creating random data for a range
76
64
  of rotations and sending it to ``analyzer`` via its __call__ method
77
65
 
66
+ >>> state = analyzer.init_state()
78
67
  >>> for rotation_number in range(10):
79
68
  >>> scores = np.random.rand(50,50)
80
69
  >>> rotation = np.random.rand(scores.ndim, scores.ndim)
81
- >>> analyzer(scores = scores, rotation_matrix = rotation)
70
+ >>> state = analyzer(state, scores = scores, rotation_matrix = rotation)
82
71
 
83
- The aggregated scores can be extracted by invoking the __iter__ method of
72
+ The aggregated scores can be extracted by invoking the result method of
84
73
  ``analyzer``
85
74
 
86
- >>> results = tuple(analyzer)
75
+ >>> results = analyzer.result(state)
87
76
 
88
77
  The ``results`` tuple contains (1) the maximum scores for each translation,
89
78
  (2) an offset which is relevant when merging results from split template matching
90
79
  using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
91
- score for a given translation, (4) a dictionary mapping rotation matrices to the
92
- indices used in (2).
80
+ score for a given translation, (4) a dictionary mapping indices used in (3) to
81
+ rotation matrices.
93
82
 
94
83
  We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
95
84
  as follows
96
85
 
97
86
  >>> optimal_score = results[0].max()
98
87
  >>> optimal_translation = np.where(results[0] == results[0].max())
99
- >>> optimal_rotation_index = results[2][optimal_translation]
100
- >>> for key, value in results[3].items():
101
- >>> if value != optimal_rotation_index:
102
- >>> continue
103
- >>> optimal_rotation = np.frombuffer(key, rotation.dtype)
104
- >>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
88
+ >>> optimal_rotation = results[2][optimal_translation]
105
89
 
106
90
  The outlined procedure is a trivial method to identify high scoring peaks.
107
91
  Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
@@ -111,156 +95,213 @@ class MaxScoreOverRotations:
111
95
  def __init__(
112
96
  self,
113
97
  shape: Tuple[int],
114
- scores: BackendArray = None,
115
- rotations: BackendArray = None,
116
98
  offset: BackendArray = None,
117
99
  score_threshold: float = 0,
118
100
  shm_handler: object = None,
119
101
  use_memmap: bool = False,
120
- thread_safe: bool = True,
121
- only_unique_rotations: bool = False,
102
+ inversion_mapping: bool = False,
122
103
  **kwargs,
123
104
  ):
105
+ self._use_memmap = use_memmap
106
+ self._score_threshold = score_threshold
124
107
  self._shape = tuple(int(x) for x in shape)
125
-
126
- self.scores = scores
127
- if self.scores is None:
128
- self.scores = be.full(
129
- shape=self._shape, dtype=be._float_dtype, fill_value=score_threshold
130
- )
131
- self.rotations = rotations
132
- if self.rotations is None:
133
- self.rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
134
-
135
- self.scores = be.to_sharedarr(self.scores, shm_handler)
136
- self.rotations = be.to_sharedarr(self.rotations, shm_handler)
108
+ self._inversion_mapping = inversion_mapping
137
109
 
138
110
  if offset is None:
139
111
  offset = be.zeros(len(self._shape), be._int_dtype)
140
- self.offset = be.astype(be.to_backend_array(offset), int)
112
+ self._offset = be.astype(be.to_backend_array(offset), int)
141
113
 
142
- self._use_memmap = use_memmap
143
- self._lock = Manager().Lock() if thread_safe else nullcontext()
144
- self._lock_is_nullcontext = isinstance(self.scores, type(be.zeros((1))))
145
- self._inversion_mapping = self._lock_is_nullcontext and only_unique_rotations
146
- self.rotation_mapping = Manager().dict() if thread_safe else {}
114
+ @property
115
+ def shareable(self):
116
+ return True
117
+
118
+ def init_state(self):
119
+ """
120
+ Initialize the analysis state.
147
121
 
148
- def _postprocess(
122
+ Returns
123
+ -------
124
+ tuple
125
+ Initial state tuple containing (scores, rotations, rotation_mapping) where:
126
+ - scores : BackendArray of shape `self._shape` filled with `score_threshold`.
127
+ - rotations : BackendArray of shape `self._shape` filled with -1.
128
+ - rotation_mapping : dict, empty mapping from rotation bytes to indices.
129
+ """
130
+ scores = be.full(
131
+ shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
132
+ )
133
+ rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
134
+ return scores, rotations, {}
135
+
136
+ def __call__(
149
137
  self,
150
- targetshape: Tuple[int],
151
- templateshape: Tuple[int],
152
- convolution_shape: Tuple[int],
138
+ state: Tuple,
139
+ scores: BackendArray,
140
+ rotation_matrix: BackendArray,
141
+ ) -> Tuple:
142
+ """
143
+ Update the parameter store.
144
+
145
+ Parameters
146
+ ----------
147
+ state : tuple
148
+ Current state tuple (scores, rotations, rotation_mapping) where:
149
+ - scores : BackendArray, current maximum scores.
150
+ - rotations : BackendArray, current rotation indices.
151
+ - rotation_mapping : dict, mapping from rotation bytes to indices.
152
+ scores : BackendArray
153
+ Array of new scores to update analyzer with.
154
+ rotation_matrix : BackendArray
155
+ Square matrix used to obtain the current rotation.
156
+ Returns
157
+ -------
158
+ tuple
159
+ Updated state tuple (scores, rotations, rotation_mapping).
160
+ """
161
+ # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
162
+ # If the analyzer is not shared and each rotation is unique, we can
163
+ # use index to rotation mapping and invert prior to merging.
164
+ prev_scores, rotations, rotation_mapping = state
165
+
166
+ rotation_index = len(rotation_mapping)
167
+ rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
168
+ if self._inversion_mapping:
169
+ rotation_mapping[rotation_index] = rotation_matrix
170
+ else:
171
+ rotation = be.tobytes(rotation_matrix)
172
+ rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
173
+
174
+ scores, rotations = be.max_score_over_rotations(
175
+ scores=scores,
176
+ max_scores=prev_scores,
177
+ rotations=rotations,
178
+ rotation_index=rotation_index,
179
+ )
180
+ return scores, rotations, rotation_mapping
181
+
182
+ @staticmethod
183
+ def _invert_rmap(rotation_mapping: dict) -> dict:
184
+ """
185
+ Invert dictionary from rotation matrix bytestrings mapping to rotation
186
+ indices to rotation indices mapping to rotation matrices.
187
+ """
188
+ new_map, ndim = {}, None
189
+ for k, v in rotation_mapping.items():
190
+ nbytes = be.datatype_bytes(be._float_dtype)
191
+ dtype = np.float32 if nbytes == 4 else np.float16
192
+ rmat = np.frombuffer(k, dtype=dtype)
193
+ if ndim is None:
194
+ ndim = int(np.sqrt(rmat.size))
195
+ new_map[v] = rmat.reshape(ndim, ndim)
196
+ return new_map
197
+
198
+ def result(
199
+ self,
200
+ state,
201
+ targetshape: Tuple[int] = None,
202
+ templateshape: Tuple[int] = None,
203
+ convolution_shape: Tuple[int] = None,
153
204
  fourier_shift: Tuple[int] = None,
154
205
  convolution_mode: str = None,
155
- shm_handler=None,
156
206
  **kwargs,
157
- ) -> "MaxScoreOverRotations":
158
- """Correct padding to Fourier shape and convolution mode."""
159
- scores = be.from_sharedarr(self.scores)
160
- rotations = be.from_sharedarr(self.rotations)
207
+ ) -> Tuple:
208
+ """
209
+ Finalize the analysis result with optional postprocessing.
210
+
211
+ Parameters
212
+ ----------
213
+ state : tuple
214
+ Current state tuple (scores, rotations, rotation_mapping) where:
215
+ - scores : BackendArray, current maximum scores.
216
+ - rotations : BackendArray, current rotation indices.
217
+ - rotation_mapping : dict, mapping from rotation indices to matrices.
218
+ targetshape : Tuple[int], optional
219
+ Shape of the target for convolution mode correction.
220
+ templateshape : Tuple[int], optional
221
+ Shape of the template for convolution mode correction.
222
+ convolution_shape : Tuple[int], optional
223
+ Shape used for convolution.
224
+ fourier_shift : Tuple[int], optional
225
+ Shift to apply for Fourier correction.
226
+ convolution_mode : str, optional
227
+ Convolution mode for padding correction.
228
+ **kwargs
229
+ Additional keyword arguments.
230
+
231
+ Returns
232
+ -------
233
+ tuple
234
+ Final result tuple (scores, offset, rotations, rotation_mapping).
235
+ """
236
+ scores, rotations, rotation_mapping = state
237
+
238
+ # Apply postprocessing if parameters are provided
161
239
  if fourier_shift is not None:
162
240
  axis = tuple(i for i in range(len(fourier_shift)))
163
241
  scores = be.roll(scores, shift=fourier_shift, axis=axis)
164
242
  rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
165
243
 
166
- convargs = {
167
- "s1": targetshape,
168
- "s2": templateshape,
169
- "convolution_mode": convolution_mode,
170
- "convolution_shape": convolution_shape,
171
- }
172
244
  if convolution_mode is not None:
245
+ convargs = {
246
+ "s1": targetshape,
247
+ "s2": templateshape,
248
+ "convolution_mode": convolution_mode,
249
+ "convolution_shape": convolution_shape,
250
+ }
173
251
  scores = apply_convolution_mode(scores, **convargs)
174
252
  rotations = apply_convolution_mode(rotations, **convargs)
175
253
 
176
- self._shape, self.scores, self.rotations = scores.shape, scores, rotations
177
- if shm_handler is not None:
178
- self.scores = be.to_sharedarr(scores, shm_handler)
179
- self.rotations = be.to_sharedarr(rotations, shm_handler)
180
- return self
181
-
182
- def __iter__(self) -> Generator:
183
- scores = be.from_sharedarr(self.scores)
184
- rotations = be.from_sharedarr(self.rotations)
185
-
186
254
  scores = be.to_numpy_array(scores)
187
255
  rotations = be.to_numpy_array(rotations)
188
256
  if self._use_memmap:
189
257
  scores = array_to_memmap(scores)
190
258
  rotations = array_to_memmap(rotations)
191
- else:
192
- if type(self.scores) is not type(scores):
193
- # Copy to avoid invalidation by shared memory handler
194
- scores, rotations = scores.copy(), rotations.copy()
195
259
 
196
260
  if self._inversion_mapping:
197
- self.rotation_mapping = {
198
- be.tobytes(v): k for k, v in self.rotation_mapping.items()
199
- }
261
+ rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
200
262
 
201
- param_store = (
263
+ return (
202
264
  scores,
203
- be.to_numpy_array(self.offset),
265
+ be.to_numpy_array(self._offset),
204
266
  rotations,
205
- dict(self.rotation_mapping),
267
+ self._invert_rmap(rotation_mapping),
206
268
  )
207
- yield from param_store
208
269
 
209
- def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
270
+ def _harmonize_states(states: List[Tuple]):
210
271
  """
211
- Update the parameter store.
212
-
213
- Parameters
214
- ----------
215
- scores : BackendArray
216
- Array of scores.
217
- rotation_matrix : BackendArray
218
- Square matrix describing the current rotation.
272
+ Create consistent reference frame for merging different analyzer
273
+ instances, w.r.t. to rotations and output shape from different
274
+ splits of the target.
219
275
  """
220
- # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
221
- # If the analyzer is not shared and each rotation is unique, we can
222
- # use index to rotation mapping and invert prior to merging.
223
- if self._lock_is_nullcontext:
224
- rotation_index = len(self.rotation_mapping)
225
- if self._inversion_mapping:
226
- self.rotation_mapping[rotation_index] = rotation_matrix
227
- else:
228
- rotation = be.tobytes(rotation_matrix)
229
- rotation_index = self.rotation_mapping.setdefault(
230
- rotation, rotation_index
231
- )
232
- self.scores, self.rotations = be.max_score_over_rotations(
233
- scores=scores,
234
- max_scores=self.scores,
235
- rotations=self.rotations,
236
- rotation_index=rotation_index,
237
- )
238
- return None
276
+ new_rotation_mapping, out_shape = {}, None
277
+ for i in range(len(states)):
278
+ if states[i] is None:
279
+ continue
239
280
 
240
- rotation = be.tobytes(rotation_matrix)
241
- with self._lock:
242
- rotation_index = self.rotation_mapping.setdefault(
243
- rotation, len(self.rotation_mapping)
244
- )
245
- internal_scores = be.from_sharedarr(self.scores)
246
- internal_rotations = be.from_sharedarr(self.rotations)
247
- internal_sores, internal_rotations = be.max_score_over_rotations(
248
- scores=scores,
249
- max_scores=internal_scores,
250
- rotations=internal_rotations,
251
- rotation_index=rotation_index,
252
- )
253
- return None
281
+ scores, offset, rotations, rotation_mapping = states[i]
282
+ if out_shape is None:
283
+ out_shape = np.zeros(scores.ndim, int)
284
+ out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
285
+
286
+ new_param = {}
287
+ for key, value in rotation_mapping.items():
288
+ rotation_bytes = be.tobytes(value)
289
+ new_param[rotation_bytes] = key
290
+ if rotation_bytes not in new_rotation_mapping:
291
+ new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
292
+ states[i] = (scores, offset, rotations, new_param)
293
+ out_shape = tuple(int(x) for x in out_shape)
294
+ return new_rotation_mapping, out_shape, states
254
295
 
255
296
  @classmethod
256
- def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
297
+ def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
257
298
  """
258
299
  Merge multiple instances of the current class.
259
300
 
260
301
  Parameters
261
302
  ----------
262
- param_stores : list of tuple
263
- List of instance's internal state created by applying `tuple(instance)`.
303
+ results : list of tuple
304
+ List of instance's internal state created by applying `result`.
264
305
  **kwargs : dict, optional
265
306
  Optional keyword arguments.
266
307
 
@@ -276,8 +317,8 @@ class MaxScoreOverRotations:
276
317
  Mapping between rotations and rotation indices.
277
318
  """
278
319
  use_memmap = kwargs.get("use_memmap", False)
279
- if len(param_stores) == 1:
280
- ret = param_stores[0]
320
+ if len(results) == 1:
321
+ ret = results[0]
281
322
  if use_memmap:
282
323
  scores, offset, rotations, rotation_mapping = ret
283
324
  scores = array_to_memmap(scores)
@@ -287,25 +328,12 @@ class MaxScoreOverRotations:
287
328
  return ret
288
329
 
289
330
  # Determine output array shape and create consistent rotation map
290
- new_rotation_mapping, out_shape = {}, None
291
- for i in range(len(param_stores)):
292
- if param_stores[i] is None:
293
- continue
294
-
295
- scores, offset, rotations, rotation_mapping = param_stores[i]
296
- if out_shape is None:
297
- out_shape = np.zeros(scores.ndim, int)
298
- scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
299
- out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
300
-
301
- for key, value in rotation_mapping.items():
302
- if key not in new_rotation_mapping:
303
- new_rotation_mapping[key] = len(new_rotation_mapping)
304
-
331
+ master_rotation_mapping, out_shape, results = cls._harmonize_states(results)
305
332
  if out_shape is None:
306
333
  return None
307
334
 
308
- out_shape = tuple(int(x) for x in out_shape)
335
+ scores_dtype = results[0][0].dtype
336
+ rotations_dtype = results[0][2].dtype
309
337
  if use_memmap:
310
338
  scores_out_filename = generate_tempfile_name()
311
339
  rotations_out_filename = generate_tempfile_name()
@@ -331,8 +359,8 @@ class MaxScoreOverRotations:
331
359
  )
332
360
  rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
333
361
 
334
- for i in range(len(param_stores)):
335
- if param_stores[i] is None:
362
+ for i in range(len(results)):
363
+ if results[i] is None:
336
364
  continue
337
365
 
338
366
  if use_memmap:
@@ -348,7 +376,7 @@ class MaxScoreOverRotations:
348
376
  shape=out_shape,
349
377
  dtype=rotations_dtype,
350
378
  )
351
- scores, offset, rotations, rotation_mapping = param_stores[i]
379
+ scores, offset, rotations, rotation_mapping = results[i]
352
380
  stops = np.add(offset, scores.shape).astype(int)
353
381
  indices = tuple(slice(*pos) for pos in zip(offset, stops))
354
382
 
@@ -359,7 +387,7 @@ class MaxScoreOverRotations:
359
387
  len(rotation_mapping) + 1, dtype=rotations_out.dtype
360
388
  )
361
389
  for key, value in rotation_mapping.items():
362
- lookup_table[value] = new_rotation_mapping[key]
390
+ lookup_table[value] = master_rotation_mapping[key]
363
391
 
364
392
  updated_rotations = rotations[indices_update]
365
393
  if len(updated_rotations):
@@ -372,7 +400,7 @@ class MaxScoreOverRotations:
372
400
  rotations_out.flush()
373
401
  scores_out, rotations_out = None, None
374
402
 
375
- param_stores[i] = None
403
+ results[i] = None
376
404
  scores, rotations = None, None
377
405
 
378
406
  if use_memmap:
@@ -390,13 +418,166 @@ class MaxScoreOverRotations:
390
418
  scores_out,
391
419
  np.zeros(scores_out.ndim, dtype=int),
392
420
  rotations_out,
393
- new_rotation_mapping,
421
+ cls._invert_rmap(master_rotation_mapping),
394
422
  )
395
423
 
396
- @property
397
- def is_shareable(self) -> bool:
398
- """Boolean indicating whether class instance can be shared across processes."""
399
- return True
424
+
425
+ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
426
+ """
427
+ Implements constrained template matching using rejection sampling.
428
+
429
+ Parameters
430
+ ----------
431
+ cone_angle : float
432
+ Maximum accepted rotational deviation in degrees.
433
+ positions : BackendArray
434
+ Array of shape (n, d) with n seed point translations.
435
+ rotations : BackendArray
436
+ Array of shape (n, d, d) with n seed point rotation matrices.
437
+ reference : BackendArray
438
+ Reference orientation of the template, wlog defaults to (0,0,1).
439
+ acceptance_radius : int or tuple of ints
440
+ Translational acceptance radius around seed point in voxels.
441
+ **kwargs : dict, optional
442
+ Keyword arguments passed to the constructor of :py:class:`MaxScoreOverRotations`.
443
+ """
444
+
445
+ def __init__(
446
+ self,
447
+ cone_angle: float,
448
+ positions: BackendArray,
449
+ rotations: BackendArray,
450
+ reference: BackendArray = (0, 0, 1),
451
+ acceptance_radius: int = 10,
452
+ **kwargs,
453
+ ):
454
+ MaxScoreOverRotations.__init__(self, **kwargs)
455
+
456
+ if not isinstance(acceptance_radius, (int, Tuple)):
457
+ raise ValueError("acceptance_radius needs to be of type int or tuple.")
458
+
459
+ if isinstance(acceptance_radius, int):
460
+ acceptance_radius = (
461
+ acceptance_radius,
462
+ acceptance_radius,
463
+ acceptance_radius,
464
+ )
465
+ acceptance_radius = tuple(int(x) for x in acceptance_radius)
466
+
467
+ self._cone_angle = float(np.radians(cone_angle))
468
+ self._cone_cutoff = float(np.tan(self._cone_angle))
469
+ self._reference = be.astype(
470
+ be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
471
+ )
472
+ positions = be.astype(be.to_backend_array(positions), be._int_dtype)
473
+
474
+ ndim = positions.shape[1]
475
+ rotate_mask = len(set(acceptance_radius)) != 1
476
+ extend = max(acceptance_radius)
477
+ mask = create_mask(
478
+ mask_type="ellipse",
479
+ radius=acceptance_radius,
480
+ shape=tuple(2 * extend + 1 for _ in range(ndim)),
481
+ center=tuple(extend for _ in range(ndim)),
482
+ )
483
+ self._score_mask = be.astype(be.to_backend_array(mask), be._float_dtype)
484
+
485
+ # Map position from real space to shifted score space
486
+ lower_limit = be.to_backend_array(self._offset)
487
+ positions = be.subtract(positions, lower_limit)
488
+ positions, valid_positions = cart_to_score(
489
+ positions=positions,
490
+ fast_shape=kwargs.get("fast_shape", None),
491
+ targetshape=kwargs.get("targetshape", None),
492
+ templateshape=kwargs.get("templateshape", None),
493
+ fourier_shift=kwargs.get("fourier_shift", None),
494
+ convolution_mode=kwargs.get("convolution_mode", None),
495
+ convolution_shape=kwargs.get("convolution_shape", None),
496
+ )
497
+
498
+ self._positions = positions[valid_positions]
499
+ rotations = be.to_backend_array(rotations)[valid_positions]
500
+ ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
501
+ ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
502
+ ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
503
+
504
+ self._normals_x = (rotations @ ex[..., None])[..., 0]
505
+ self._normals_y = (rotations @ ey[..., None])[..., 0]
506
+ self._normals_z = (rotations @ ez[..., None])[..., 0]
507
+
508
+ # Periodic wrapping could be avoided by padding the target
509
+ shape = be.to_backend_array(self._shape)
510
+ starts = be.subtract(self._positions, extend)
511
+ ret, (n, d), mshape = [], self._positions.shape, self._score_mask.shape
512
+ if starts.shape[0] > 0:
513
+ for i in range(d):
514
+ indices = starts[:, slice(i, i + 1)] + be.arange(mshape[i])[None]
515
+ indices = be.mod(indices, shape[i], out=indices)
516
+ indices_shape = (n, *tuple(1 if k != i else -1 for k in range(d)))
517
+ ret.append(be.reshape(indices, indices_shape))
518
+
519
+ self._index_grid = tuple(ret)
520
+ self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
521
+
522
+ if rotate_mask:
523
+ self._score_mask = be.zeros(
524
+ (rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
525
+ )
526
+ for i in range(rotations.shape[0]):
527
+ mask = create_mask(
528
+ mask_type="ellipse",
529
+ radius=acceptance_radius,
530
+ shape=tuple(2 * extend + 1 for _ in range(ndim)),
531
+ center=tuple(extend for _ in range(ndim)),
532
+ orientation=be.to_numpy_array(rotations[i]),
533
+ )
534
+ self._score_mask[i] = be.astype(
535
+ be.to_backend_array(mask), be._float_dtype
536
+ )
537
+
538
+ def __call__(
539
+ self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
540
+ ) -> Tuple:
541
+ mask = self._get_constraint(rotation_matrix)
542
+ mask = self._get_score_mask(mask=mask, scores=scores)
543
+
544
+ scores = be.multiply(scores, mask, out=scores)
545
+ return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
546
+
547
+ def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
548
+ """
549
+ Determine whether the angle between projection of reference w.r.t.
550
+ a given rotation matrix and a set of rotations fall within the set
551
+ cone_angle cutoff.
552
+
553
+ Parameters
554
+ ----------
555
+ rotation_matrix : BackendArray
556
+ Rotation matrix with shape (d,d).
557
+
558
+ Returns
559
+ -------
560
+ BackendArray
561
+ Boolean mask of shape (n, )
562
+ """
563
+ template_rot = rotation_matrix @ self._reference
564
+
565
+ x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
566
+ y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
567
+ z = be.sum(be.multiply(self._normals_z, template_rot), axis=1)
568
+
569
+ return be.sqrt(x**2 + y**2) <= (z * self._cone_cutoff)
570
+
571
+ def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
572
+ score_mask = be.zeros(scores.shape, scores.dtype)
573
+
574
+ if be.sum(mask) == 0:
575
+ return score_mask
576
+ mask = be.reshape(mask, self._mask_shape)
577
+
578
+ score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
579
+ return score_mask > 0
580
+
400
581
 
401
582
  class MaxScoreOverTranslations(MaxScoreOverRotations):
402
583
  """
@@ -436,43 +617,40 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
436
617
  super().__init__(
437
618
  shape=shape_reduced, shm_handler=shm_handler, offset=offset, **kwargs
438
619
  )
439
-
440
- self.rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
441
- self.rotations = be.to_sharedarr(self.rotations, shm_handler)
442
620
  self._aggregate_axis = aggregate_axis
443
621
 
444
- def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
445
- if self._lock_is_nullcontext:
446
- rotation_index = len(self.rotation_mapping)
447
- if self._inversion_mapping:
448
- self.rotation_mapping[rotation_index] = rotation_matrix
449
- else:
450
- rotation = be.tobytes(rotation_matrix)
451
- rotation_index = self.rotation_mapping.setdefault(
452
- rotation, rotation_index
453
- )
454
- max_score = be.max(scores, axis=self._aggregate_axis)
455
- self.scores[rotation_index] = max_score
456
- return None
622
+ def init_state(self):
623
+ scores = be.full(
624
+ shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
625
+ )
626
+ rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
627
+ return scores, rotations, {}
457
628
 
458
- rotation = be.tobytes(rotation_matrix)
459
- with self._lock:
460
- rotation_index = self.rotation_mapping.setdefault(
461
- rotation, len(self.rotation_mapping)
462
- )
463
- internal_scores = be.from_sharedarr(self.scores)
464
- max_score = be.max(scores, axis=self._aggregate_axis)
465
- internal_scores[rotation_index] = max_score
466
- return None
629
+ def __call__(
630
+ self, state, scores: BackendArray, rotation_matrix: BackendArray
631
+ ) -> Tuple:
632
+ prev_scores, rotations, rotation_mapping = state
633
+
634
+ rotation_index = len(rotation_mapping)
635
+ if self._inversion_mapping:
636
+ rotation_mapping[rotation_index] = rotation_matrix
637
+ else:
638
+ rotation = be.tobytes(rotation_matrix)
639
+ rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
640
+ max_score = be.max(scores, axis=self._aggregate_axis)
641
+
642
+ update = prev_scores[rotation_index]
643
+ update = be.maximum(max_score, update, out=update)
644
+ return prev_scores, rotations, rotation_mapping
467
645
 
468
646
  @classmethod
469
- def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
647
+ def merge(cls, states: List[Tuple], **kwargs) -> Tuple:
470
648
  """
471
649
  Merge multiple instances of the current class.
472
650
 
473
651
  Parameters
474
652
  ----------
475
- param_stores : list of tuple
653
+ states : list of tuple
476
654
  List of instance's internal state created by applying `tuple(instance)`.
477
655
  **kwargs : dict, optional
478
656
  Optional keyword arguments.
@@ -488,31 +666,18 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
488
666
  Dict
489
667
  Mapping between rotations and rotation indices.
490
668
  """
491
- if len(param_stores) == 1:
492
- return param_stores[0]
669
+ if len(states) == 1:
670
+ return states[0]
493
671
 
494
672
  # Determine output array shape and create consistent rotation map
495
- new_rotation_mapping, out_shape = {}, None
496
- for i in range(len(param_stores)):
497
- if param_stores[i] is None:
498
- continue
499
-
500
- scores, offset, rotations, rotation_mapping = param_stores[i]
501
- if out_shape is None:
502
- out_shape = np.zeros(scores.ndim, int)
503
- scores_dtype, rotations_out = scores.dtype, rotations
504
- out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
505
-
506
- for key, value in rotation_mapping.items():
507
- if key not in new_rotation_mapping:
508
- new_rotation_mapping[key] = len(new_rotation_mapping)
509
-
673
+ states, master_rotation_mapping, out_shape = cls._harmonize_states(states)
510
674
  if out_shape is None:
511
675
  return None
512
676
 
513
- out_shape[0] = len(new_rotation_mapping)
677
+ out_shape[0] = len(master_rotation_mapping)
514
678
  out_shape = tuple(int(x) for x in out_shape)
515
679
 
680
+ scores_dtype = states[0][0].dtype
516
681
  use_memmap = kwargs.get("use_memmap", False)
517
682
  if use_memmap:
518
683
  scores_out_filename = generate_tempfile_name()
@@ -528,8 +693,8 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
528
693
  dtype=scores_dtype,
529
694
  )
530
695
 
531
- for i in range(len(param_stores)):
532
- if param_stores[i] is None:
696
+ for i in range(len(states)):
697
+ if states[i] is None:
533
698
  continue
534
699
 
535
700
  if use_memmap:
@@ -539,11 +704,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
539
704
  shape=out_shape,
540
705
  dtype=scores_dtype,
541
706
  )
542
- scores, offset, rotations, rotation_mapping = param_stores[i]
707
+ scores, offset, rotations, rotation_mapping = states[i]
543
708
 
544
709
  outer_table = np.arange(len(rotation_mapping), dtype=int)
545
710
  lookup_table = np.array(
546
- [new_rotation_mapping[key] for key in rotation_mapping.keys()],
711
+ [master_rotation_mapping[key] for key in rotation_mapping.keys()],
547
712
  dtype=int,
548
713
  )
549
714
 
@@ -559,7 +724,7 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
559
724
  scores_out.flush()
560
725
  scores_out = None
561
726
 
562
- param_stores[i], scores = None, None
727
+ states[i], scores = None, None
563
728
 
564
729
  if use_memmap:
565
730
  scores_out = np.memmap(
@@ -569,8 +734,8 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
569
734
  return (
570
735
  scores_out,
571
736
  np.zeros(scores_out.ndim, dtype=int),
572
- rotations_out,
573
- new_rotation_mapping,
737
+ states[2],
738
+ cls._invert_rmap(master_rotation_mapping),
574
739
  )
575
740
 
576
741
  def _postprocess(self, **kwargs):