pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,577 @@
1
+ """ Implements classes to analyze outputs from exhaustive template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from contextlib import nullcontext
9
+ from multiprocessing import Manager
10
+ from typing import Tuple, List, Dict, Generator
11
+
12
+ import numpy as np
13
+
14
+ from ..types import BackendArray
15
+ from ._utils import cart_to_score
16
+ from ..backends import backend as be
17
+ from ..matching_utils import (
18
+ create_mask,
19
+ array_to_memmap,
20
+ generate_tempfile_name,
21
+ apply_convolution_mode,
22
+ )
23
+
24
+
25
# Public API of this module.
__all__ = ["MaxScoreOverRotations", "MaxScoreOverTranslations"]
29
+
30
+
31
class MaxScoreOverRotations:
    """
    Determine the rotation maximizing the score over all possible translations.

    Parameters
    ----------
    shape : tuple of int
        Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
    scores : BackendArray, optional
        Array mapping translations to scores.
    rotations : BackendArray, optional
        Array mapping translations to rotation indices.
    offset : BackendArray, optional
        Coordinate origin considered during merging, zero by default.
    score_threshold : float, optional
        Minimum score to be considered, zero by default.
    shm_handler : :class:`multiprocessing.managers.SharedMemoryManager`, optional
        Shared memory manager, defaults to memory not being shared.
    use_memmap : bool, optional
        Memmap internal arrays, False by default.
    thread_safe: bool, optional
        Allow class to be modified by multiple processes, True by default.
    only_unique_rotations : bool, optional
        Whether each rotation will be shown only once, False by default. Only
        takes effect when the score array is not placed in shared memory.
    **kwargs : dict, optional
        Additional keyword arguments, ignored.

    Attributes
    ----------
    scores : BackendArray
        Mapping of translations to scores.
    rotations : BackendArray
        Mapping of translations to rotation indices.
    rotation_mapping : Dict
        Mapping of rotations to rotation indices.
    offset : BackendArray, optional
        Coordinate origin considered during merging, zero by default

    Examples
    --------
    The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
    instance

    >>> from tme.analyzer import MaxScoreOverRotations
    >>> analyzer = MaxScoreOverRotations(shape = (50, 50))

    The following simulates a template matching run by creating random data for a range
    of rotations and sending it to ``analyzer`` via its __call__ method

    >>> for rotation_number in range(10):
    ...     scores = np.random.rand(50,50)
    ...     rotation = np.random.rand(scores.ndim, scores.ndim)
    ...     analyzer(scores = scores, rotation_matrix = rotation)

    The aggregated scores can be extracted by invoking the __iter__ method of
    ``analyzer``

    >>> results = tuple(analyzer)

    The ``results`` tuple contains (1) the maximum scores for each translation,
    (2) an offset which is relevant when merging results from split template matching
    using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
    score for a given translation, (4) a dictionary mapping rotation matrices to the
    indices used in (3).

    We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
    as follows

    >>> optimal_score = results[0].max()
    >>> optimal_translation = np.where(results[0] == results[0].max())
    >>> optimal_rotation_index = results[2][optimal_translation]
    >>> for key, value in results[3].items():
    ...     if value != optimal_rotation_index:
    ...         continue
    ...     optimal_rotation = np.frombuffer(key, rotation.dtype)
    ...     optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)

    The outlined procedure is a trivial method to identify high scoring peaks.
    Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
    that can be used.
    """

    def __init__(
        self,
        shape: Tuple[int],
        scores: BackendArray = None,
        rotations: BackendArray = None,
        offset: BackendArray = None,
        score_threshold: float = 0,
        shm_handler: object = None,
        use_memmap: bool = False,
        thread_safe: bool = True,
        only_unique_rotations: bool = False,
        **kwargs,
    ):
        self._shape = tuple(int(x) for x in shape)

        if scores is None:
            scores = be.full(
                shape=self._shape, dtype=be._float_dtype, fill_value=score_threshold
            )
        if rotations is None:
            rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)

        self.scores = be.to_sharedarr(scores, shm_handler)
        self.rotations = be.to_sharedarr(rotations, shm_handler)

        if offset is None:
            offset = be.zeros(len(self._shape), be._int_dtype)
        self.offset = be.astype(be.to_backend_array(offset), int)

        self._use_memmap = use_memmap
        if thread_safe:
            # One manager (server process) serves both the lock and the mapping,
            # instead of spawning a separate manager for each.
            manager = Manager()
            self._lock = manager.Lock()
            self.rotation_mapping = manager.dict()
        else:
            self._lock = nullcontext()
            self.rotation_mapping = {}

        # True when scores is still a plain backend array, i.e. it was not
        # converted into a shared-memory object by be.to_sharedarr. In that case
        # __call__ updates the arrays directly without taking the lock.
        self._lock_is_nullcontext = isinstance(self.scores, type(be.zeros((1))))
        # With unique rotations, store index -> matrix during __call__ (cheap)
        # and invert to matrix bytes -> index only when results are extracted.
        self._inversion_mapping = self._lock_is_nullcontext and only_unique_rotations

    def _postprocess(
        self,
        targetshape: Tuple[int],
        templateshape: Tuple[int],
        convolution_shape: Tuple[int],
        fourier_shift: Tuple[int] = None,
        convolution_mode: str = None,
        shm_handler=None,
        **kwargs,
    ) -> "MaxScoreOverRotations":
        """Correct padding to Fourier shape and convolution mode."""
        scores = be.from_sharedarr(self.scores)
        rotations = be.from_sharedarr(self.rotations)
        if fourier_shift is not None:
            axis = tuple(range(len(fourier_shift)))
            scores = be.roll(scores, shift=fourier_shift, axis=axis)
            rotations = be.roll(rotations, shift=fourier_shift, axis=axis)

        convargs = {
            "s1": targetshape,
            "s2": templateshape,
            "convolution_mode": convolution_mode,
            "convolution_shape": convolution_shape,
        }
        if convolution_mode is not None:
            scores = apply_convolution_mode(scores, **convargs)
            rotations = apply_convolution_mode(rotations, **convargs)

        self._shape, self.scores, self.rotations = scores.shape, scores, rotations
        if shm_handler is not None:
            self.scores = be.to_sharedarr(scores, shm_handler)
            self.rotations = be.to_sharedarr(rotations, shm_handler)
        return self

    def __iter__(self) -> Generator:
        """
        Yield scores, offset, rotations and the rotation mapping.

        Does not modify the instance, so it may be iterated multiple times.
        """
        scores = be.from_sharedarr(self.scores)
        rotations = be.from_sharedarr(self.rotations)

        scores = be.to_numpy_array(scores)
        rotations = be.to_numpy_array(rotations)
        if self._use_memmap:
            scores = array_to_memmap(scores)
            rotations = array_to_memmap(rotations)
        elif type(self.scores) is not type(scores):
            # Copy to avoid invalidation by shared memory handler
            scores, rotations = scores.copy(), rotations.copy()

        rotation_mapping = self.rotation_mapping
        if self._inversion_mapping:
            # Invert index -> matrix into matrix bytes -> index in a local copy;
            # mutating self.rotation_mapping would break repeated iteration
            # and subsequent __call__ invocations.
            rotation_mapping = {
                be.tobytes(v): k for k, v in rotation_mapping.items()
            }

        param_store = (
            scores,
            be.to_numpy_array(self.offset),
            rotations,
            dict(rotation_mapping),
        )
        yield from param_store

    def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
        """
        Update the parameter store.

        Parameters
        ----------
        scores : BackendArray
            Array of scores.
        rotation_matrix : BackendArray
            Square matrix describing the current rotation.
        """
        # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
        # If the analyzer is not shared and each rotation is unique, we can
        # use index to rotation mapping and invert prior to merging.
        if self._lock_is_nullcontext:
            rotation_index = len(self.rotation_mapping)
            if self._inversion_mapping:
                self.rotation_mapping[rotation_index] = rotation_matrix
            else:
                rotation = be.tobytes(rotation_matrix)
                rotation_index = self.rotation_mapping.setdefault(
                    rotation, rotation_index
                )
            self.scores, self.rotations = be.max_score_over_rotations(
                scores=scores,
                max_scores=self.scores,
                rotations=self.rotations,
                rotation_index=rotation_index,
            )
            return None

        rotation = be.tobytes(rotation_matrix)
        with self._lock:
            rotation_index = self.rotation_mapping.setdefault(
                rotation, len(self.rotation_mapping)
            )
            internal_scores = be.from_sharedarr(self.scores)
            internal_rotations = be.from_sharedarr(self.rotations)
            # The returned arrays are not written back to self.scores /
            # self.rotations; this path relies on the backend updating the
            # shared buffers in place.
            internal_scores, internal_rotations = be.max_score_over_rotations(
                scores=scores,
                max_scores=internal_scores,
                rotations=internal_rotations,
                rotation_index=rotation_index,
            )
        return None

    @classmethod
    def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
        """
        Merge multiple instances of the current class.

        Parameters
        ----------
        param_stores : list of tuple
            List of instance's internal state created by applying `tuple(instance)`.
            Entries are set to None as they are consumed to release memory.
        **kwargs : dict, optional
            Optional keyword arguments. ``use_memmap`` (bool) writes the output
            to temporary memory-mapped files, ``score_threshold`` (float) is the
            initial fill value of the merged scores.

        Returns
        -------
        NDArray
            Maximum score of each translation over all observed rotations.
        NDArray
            Translation offset, zero by default.
        NDArray
            Mapping between translations and rotation indices.
        Dict
            Mapping between rotations and rotation indices.
        """
        use_memmap = kwargs.get("use_memmap", False)
        score_threshold = kwargs.get("score_threshold", 0)
        if len(param_stores) == 1:
            ret = param_stores[0]
            if use_memmap and ret is not None:
                scores, offset, rotations, rotation_mapping = ret
                scores = array_to_memmap(scores)
                rotations = array_to_memmap(rotations)
                ret = (scores, offset, rotations, rotation_mapping)
            return ret

        # Determine output array shape and create consistent rotation map
        new_rotation_mapping, out_shape = {}, None
        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            scores, offset, rotations, rotation_mapping = param_stores[i]
            if out_shape is None:
                out_shape = np.zeros(scores.ndim, int)
                scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
            out_shape = np.maximum(out_shape, np.add(offset, scores.shape))

            for key in rotation_mapping:
                if key not in new_rotation_mapping:
                    new_rotation_mapping[key] = len(new_rotation_mapping)

        if out_shape is None:
            return None

        out_shape = tuple(int(x) for x in out_shape)
        if use_memmap:
            scores_out_filename = generate_tempfile_name()
            rotations_out_filename = generate_tempfile_name()

            scores_out = np.memmap(
                scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
            )
            scores_out.fill(score_threshold)
            scores_out.flush()
            rotations_out = np.memmap(
                rotations_out_filename,
                mode="w+",
                shape=out_shape,
                dtype=rotations_dtype,
            )
            rotations_out.fill(-1)
            rotations_out.flush()
        else:
            scores_out = np.full(
                out_shape,
                fill_value=score_threshold,
                dtype=scores_dtype,
            )
            rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)

        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            if use_memmap:
                scores_out = np.memmap(
                    scores_out_filename,
                    mode="r+",
                    shape=out_shape,
                    dtype=scores_dtype,
                )
                rotations_out = np.memmap(
                    rotations_out_filename,
                    mode="r+",
                    shape=out_shape,
                    dtype=rotations_dtype,
                )
            scores, offset, rotations, rotation_mapping = param_stores[i]
            stops = np.add(offset, scores.shape).astype(int)
            indices = tuple(slice(*pos) for pos in zip(offset, stops))

            # Slicing yields a view, so the boolean assignment writes into
            # scores_out / rotations_out.
            indices_update = scores > scores_out[indices]
            scores_out[indices][indices_update] = scores[indices_update]

            # Translate store-local rotation indices into merged indices.
            lookup_table = np.arange(
                len(rotation_mapping) + 1, dtype=rotations_out.dtype
            )
            for key, value in rotation_mapping.items():
                lookup_table[value] = new_rotation_mapping[key]

            updated_rotations = rotations[indices_update]
            if len(updated_rotations):
                rotations_out[indices][indices_update] = lookup_table[updated_rotations]

            if use_memmap:
                # Inputs need not be memmaps (e.g. plain arrays), so only close
                # an underlying mmap when one is present.
                for array in (scores, rotations):
                    array_mmap = getattr(array, "_mmap", None)
                    if array_mmap is not None:
                        array_mmap.close()
                scores_out.flush()
                rotations_out.flush()
                scores_out, rotations_out = None, None

            param_stores[i] = None
            scores, rotations = None, None

        if use_memmap:
            scores_out = np.memmap(
                scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
            )
            rotations_out = np.memmap(
                rotations_out_filename,
                mode="r",
                shape=out_shape,
                dtype=rotations_dtype,
            )

        return (
            scores_out,
            np.zeros(scores_out.ndim, dtype=int),
            rotations_out,
            new_rotation_mapping,
        )

    @property
    def is_shareable(self) -> bool:
        """Boolean indicating whether class instance can be shared across processes."""
        return True
400
+
401
class MaxScoreOverTranslations(MaxScoreOverRotations):
    """
    Determine the translation maximizing the score over all possible rotations.

    Parameters
    ----------
    shape : tuple of int
        Shape of array passed to :py:meth:`MaxScoreOverTranslations.__call__`.
    n_rotations : int
        Number of rotations to aggregate over.
    aggregate_axis : tuple of int, optional
        Array axis to aggregate over, negative values are supported. None (the
        default) is forwarded to ``be.max`` unchanged and all axes are removed
        from the stored shape, leaving one score per rotation.
    shm_handler : :class:`multiprocessing.managers.SharedMemoryManager`, optional
        Shared memory manager, defaults to memory not being shared.
    offset : tuple of int, optional
        Coordinate origin of an array of ``shape``, zero by default. Entries of
        aggregated axes are dropped and a leading zero is added for the
        rotation axis.
    **kwargs: dict, optional
        Keyword arguments passed to the constructor of the parent class.
    """

    def __init__(
        self,
        shape: Tuple[int],
        n_rotations: int,
        aggregate_axis: Tuple[int] = None,
        shm_handler: object = None,
        offset: Tuple[int] = None,
        **kwargs: Dict,
    ):
        ndim = len(shape)
        if aggregate_axis is None:
            # axis=None is passed on to be.max, which (numpy-style) reduces
            # every axis, so no axis of ``shape`` is retained.
            reduced_axes = set(range(ndim))
        else:
            # Normalize negative axes so they are dropped from the shape too.
            reduced_axes = {int(x) % ndim for x in np.atleast_1d(aggregate_axis)}

        # Leading axis enumerates rotations, followed by the retained axes.
        shape_reduced = [n_rotations]
        shape_reduced.extend(x for i, x in enumerate(shape) if i not in reduced_axes)

        if offset is None:
            offset = be.zeros(ndim, be._int_dtype)
        offset_reduced = [0]
        offset_reduced.extend(x for i, x in enumerate(offset) if i not in reduced_axes)

        # Rotations are identified by the leading score axis, so a one-element
        # placeholder suffices. Handing it to the parent avoids allocating (and
        # sharing) a full-size rotations array that would be discarded.
        kwargs["rotations"] = be.full(1, dtype=be._int_dtype, fill_value=-1)
        super().__init__(
            shape=shape_reduced,
            shm_handler=shm_handler,
            offset=offset_reduced,
            **kwargs,
        )
        self._aggregate_axis = aggregate_axis

    def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
        """
        Store the maximum of ``scores`` over the aggregated axes.

        Parameters
        ----------
        scores : BackendArray
            Array of scores for the current rotation.
        rotation_matrix : BackendArray
            Square matrix describing the current rotation.
        """
        if self._lock_is_nullcontext:
            rotation_index = len(self.rotation_mapping)
            if self._inversion_mapping:
                self.rotation_mapping[rotation_index] = rotation_matrix
            else:
                rotation = be.tobytes(rotation_matrix)
                rotation_index = self.rotation_mapping.setdefault(
                    rotation, rotation_index
                )
            max_score = be.max(scores, axis=self._aggregate_axis)
            # Overwrites the row of this rotation rather than taking a maximum.
            self.scores[rotation_index] = max_score
            return None

        rotation = be.tobytes(rotation_matrix)
        with self._lock:
            rotation_index = self.rotation_mapping.setdefault(
                rotation, len(self.rotation_mapping)
            )
            internal_scores = be.from_sharedarr(self.scores)
            max_score = be.max(scores, axis=self._aggregate_axis)
            internal_scores[rotation_index] = max_score
        return None

    @classmethod
    def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
        """
        Merge multiple instances of the current class.

        Parameters
        ----------
        param_stores : list of tuple
            List of instance's internal state created by applying `tuple(instance)`.
            Entries are set to None as they are consumed to release memory.
        **kwargs : dict, optional
            Optional keyword arguments. ``use_memmap`` (bool) writes the output
            to a temporary memory-mapped file, ``score_threshold`` (float) is the
            initial fill value of the merged scores.

        Returns
        -------
        NDArray
            Maximum score of each rotation over all observed translations.
        NDArray
            Translation offset, zero by default.
        NDArray
            Placeholder rotations array of the first non-empty store.
        Dict
            Mapping between rotations and rotation indices.
        """
        use_memmap = kwargs.get("use_memmap", False)
        score_threshold = kwargs.get("score_threshold", 0)
        if len(param_stores) == 1:
            ret = param_stores[0]
            if use_memmap and ret is not None:
                scores, offset, rotations, rotation_mapping = ret
                ret = (array_to_memmap(scores), offset, rotations, rotation_mapping)
            return ret

        # Determine output array shape and create consistent rotation map
        new_rotation_mapping, out_shape = {}, None
        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            scores, offset, rotations, rotation_mapping = param_stores[i]
            if out_shape is None:
                out_shape = np.zeros(scores.ndim, int)
                scores_dtype, rotations_out = scores.dtype, rotations
            out_shape = np.maximum(out_shape, np.add(offset, scores.shape))

            for key in rotation_mapping:
                if key not in new_rotation_mapping:
                    new_rotation_mapping[key] = len(new_rotation_mapping)

        if out_shape is None:
            return None

        # One output row per distinct rotation across all stores.
        out_shape[0] = len(new_rotation_mapping)
        out_shape = tuple(int(x) for x in out_shape)

        if use_memmap:
            scores_out_filename = generate_tempfile_name()
            scores_out = np.memmap(
                scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
            )
            scores_out.fill(score_threshold)
            scores_out.flush()
        else:
            scores_out = np.full(
                out_shape,
                fill_value=score_threshold,
                dtype=scores_dtype,
            )

        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            if use_memmap:
                scores_out = np.memmap(
                    scores_out_filename,
                    mode="r+",
                    shape=out_shape,
                    dtype=scores_dtype,
                )
            scores, offset, rotations, rotation_mapping = param_stores[i]

            # Rows of ``scores`` are indexed by the store-local rotation index
            # (the mapping values); map each onto its merged index.
            mapping_items = list(rotation_mapping.items())
            source_rows = np.array([value for _, value in mapping_items], dtype=int)
            target_rows = np.array(
                [new_rotation_mapping[key] for key, _ in mapping_items], dtype=int
            )

            stops = np.add(offset, scores.shape).astype(int)
            indices = [slice(*pos) for pos in zip(offset[1:], stops[1:])]
            indices.insert(0, target_rows)
            indices = tuple(indices)

            scores_out[indices] = np.maximum(scores_out[indices], scores[source_rows])

            if use_memmap:
                # Inputs need not be memmaps, only close an existing mmap.
                scores_mmap = getattr(scores, "_mmap", None)
                if scores_mmap is not None:
                    scores_mmap.close()
                scores_out.flush()
                scores_out = None

            param_stores[i], scores = None, None

        if use_memmap:
            scores_out = np.memmap(
                scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
            )

        return (
            scores_out,
            np.zeros(scores_out.ndim, dtype=int),
            rotations_out,
            new_rotation_mapping,
        )

    def _postprocess(self, **kwargs):
        """Return the instance unchanged; no padding correction is applied."""
        return self