pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
tme/analyzer/aggregation.py
@@ -0,0 +1,577 @@
""" Implements classes to analyze outputs from exhaustive template matching.

    Copyright (c) 2023 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from contextlib import nullcontext
from multiprocessing import Manager
from typing import Tuple, List, Dict, Generator

import numpy as np

from ..types import BackendArray
from ._utils import cart_to_score
from ..backends import backend as be
from ..matching_utils import (
    create_mask,
    array_to_memmap,
    generate_tempfile_name,
    apply_convolution_mode,
)


__all__ = [
    "MaxScoreOverRotations",
    "MaxScoreOverTranslations",
]


class MaxScoreOverRotations:
    """
    Determine the rotation maximizing the score over all possible translations.

    Parameters
    ----------
    shape : tuple of int
        Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
    scores : BackendArray, optional
        Array mapping translations to scores.
    rotations : BackendArray, optional
        Array mapping translations to rotation indices.
    offset : BackendArray, optional
        Coordinate origin considered during merging, zero by default.
    score_threshold : float, optional
        Minimum score to be considered, zero by default.
    shm_handler : :class:`multiprocessing.managers.SharedMemoryManager`, optional
        Shared memory manager, defaults to memory not being shared.
    use_memmap : bool, optional
        Memmap internal arrays, False by default.
    thread_safe : bool, optional
        Allow class to be modified by multiple processes, True by default.
    only_unique_rotations : bool, optional
        Whether each rotation will be shown only once, False by default.

    Attributes
    ----------
    scores : BackendArray
        Mapping of translations to scores.
    rotations : BackendArray
        Mapping of translations to rotation indices.
    rotation_mapping : Dict
        Mapping of rotations to rotation indices.
    offset : BackendArray, optional
        Coordinate origin considered during merging, zero by default.

    Examples
    --------
    The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
    instance

    >>> import numpy as np
    >>> from tme.analyzer import MaxScoreOverRotations
    >>> analyzer = MaxScoreOverRotations(shape = (50, 50))

    The following simulates a template matching run by creating random data for a range
    of rotations and sending it to ``analyzer`` via its __call__ method

    >>> for rotation_number in range(10):
    >>>     scores = np.random.rand(50,50)
    >>>     rotation = np.random.rand(scores.ndim, scores.ndim)
    >>>     analyzer(scores = scores, rotation_matrix = rotation)

    The aggregated scores can be extracted by invoking the __iter__ method of
    ``analyzer``

    >>> results = tuple(analyzer)

    The ``results`` tuple contains (1) the maximum scores for each translation,
    (2) an offset which is relevant when merging results from split template matching
    using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
    score for a given translation, (4) a dictionary mapping rotation matrices to the
    indices used in (3).

    We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
    as follows

    >>> optimal_score = results[0].max()
    >>> optimal_translation = np.where(results[0] == results[0].max())
    >>> optimal_rotation_index = results[2][optimal_translation]
    >>> for key, value in results[3].items():
    >>>     if value != optimal_rotation_index:
    >>>         continue
    >>>     optimal_rotation = np.frombuffer(key, rotation.dtype)
    >>>     optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)

    The outlined procedure is a trivial method to identify high scoring peaks.
    Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
    that can be used.
    """

    def __init__(
        self,
        shape: Tuple[int],
        scores: BackendArray = None,
        rotations: BackendArray = None,
        offset: BackendArray = None,
        score_threshold: float = 0,
        shm_handler: object = None,
        use_memmap: bool = False,
        thread_safe: bool = True,
        only_unique_rotations: bool = False,
        **kwargs,
    ):
        self._shape = tuple(int(x) for x in shape)

        self.scores = scores
        if self.scores is None:
            self.scores = be.full(
                shape=self._shape, dtype=be._float_dtype, fill_value=score_threshold
            )
        self.rotations = rotations
        if self.rotations is None:
            self.rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)

        self.scores = be.to_sharedarr(self.scores, shm_handler)
        self.rotations = be.to_sharedarr(self.rotations, shm_handler)

        if offset is None:
            offset = be.zeros(len(self._shape), be._int_dtype)
        self.offset = be.astype(be.to_backend_array(offset), int)

        self._use_memmap = use_memmap
        self._lock = Manager().Lock() if thread_safe else nullcontext()
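        # If the scores are still a plain backend array (no shared-memory handle was
        # created), __call__ can update them in place without acquiring the lock.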
        self._lock_is_nullcontext = isinstance(self.scores, type(be.zeros((1))))
        self._inversion_mapping = self._lock_is_nullcontext and only_unique_rotations
        self.rotation_mapping = Manager().dict() if thread_safe else {}

    def _postprocess(
        self,
        targetshape: Tuple[int],
        templateshape: Tuple[int],
        convolution_shape: Tuple[int],
        fourier_shift: Tuple[int] = None,
        convolution_mode: str = None,
        shm_handler=None,
        **kwargs,
    ) -> "MaxScoreOverRotations":
        """Correct padding to Fourier shape and convolution mode."""
        scores = be.from_sharedarr(self.scores)
        rotations = be.from_sharedarr(self.rotations)
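        # Realign scores and rotations with the target grid by rolling them by fourier_shift.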
        if fourier_shift is not None:
            axis = tuple(i for i in range(len(fourier_shift)))
            scores = be.roll(scores, shift=fourier_shift, axis=axis)
            rotations = be.roll(rotations, shift=fourier_shift, axis=axis)

        convargs = {
            "s1": targetshape,
            "s2": templateshape,
            "convolution_mode": convolution_mode,
            "convolution_shape": convolution_shape,
        }
        if convolution_mode is not None:
            scores = apply_convolution_mode(scores, **convargs)
            rotations = apply_convolution_mode(rotations, **convargs)

        self._shape, self.scores, self.rotations = scores.shape, scores, rotations
        if shm_handler is not None:
            self.scores = be.to_sharedarr(scores, shm_handler)
            self.rotations = be.to_sharedarr(rotations, shm_handler)
        return self

    def __iter__(self) -> Generator:
        scores = be.from_sharedarr(self.scores)
        rotations = be.from_sharedarr(self.rotations)

        scores = be.to_numpy_array(scores)
        rotations = be.to_numpy_array(rotations)
        if self._use_memmap:
            scores = array_to_memmap(scores)
            rotations = array_to_memmap(rotations)
        else:
            if type(self.scores) is not type(scores):
                # Copy to avoid invalidation by shared memory handler
                scores, rotations = scores.copy(), rotations.copy()

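        # Invert the index -> matrix mapping built in __call__ back into the
        # bytes -> index form expected by merge.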
        if self._inversion_mapping:
            self.rotation_mapping = {
                be.tobytes(v): k for k, v in self.rotation_mapping.items()
            }

        param_store = (
            scores,
            be.to_numpy_array(self.offset),
            rotations,
            dict(self.rotation_mapping),
        )
        yield from param_store

    def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
        """
        Update the parameter store.

        Parameters
        ----------
        scores : BackendArray
            Array of scores.
        rotation_matrix : BackendArray
            Square matrix describing the current rotation.
        """
        # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
        # If the analyzer is not shared and each rotation is unique, we can
        # use index to rotation mapping and invert prior to merging.
        if self._lock_is_nullcontext:
            rotation_index = len(self.rotation_mapping)
            if self._inversion_mapping:
                self.rotation_mapping[rotation_index] = rotation_matrix
            else:
                rotation = be.tobytes(rotation_matrix)
                rotation_index = self.rotation_mapping.setdefault(
                    rotation, rotation_index
                )
            self.scores, self.rotations = be.max_score_over_rotations(
                scores=scores,
                max_scores=self.scores,
                rotations=self.rotations,
                rotation_index=rotation_index,
            )
            return None

        rotation = be.tobytes(rotation_matrix)
        with self._lock:
            rotation_index = self.rotation_mapping.setdefault(
                rotation, len(self.rotation_mapping)
            )
            internal_scores = be.from_sharedarr(self.scores)
            internal_rotations = be.from_sharedarr(self.rotations)
            internal_scores, internal_rotations = be.max_score_over_rotations(
                scores=scores,
                max_scores=internal_scores,
                rotations=internal_rotations,
                rotation_index=rotation_index,
            )
        return None

    @classmethod
    def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
        """
        Merge multiple instances of the current class.

        Parameters
        ----------
        param_stores : list of tuple
            List of instance's internal state created by applying `tuple(instance)`.
        **kwargs : dict, optional
            Optional keyword arguments.

        Returns
        -------
        NDArray
            Maximum score of each translation over all observed rotations.
        NDArray
            Translation offset, zero by default.
        NDArray
            Mapping between translations and rotation indices.
        Dict
            Mapping between rotations and rotation indices.
        """
        use_memmap = kwargs.get("use_memmap", False)
        if len(param_stores) == 1:
            ret = param_stores[0]
            if use_memmap:
                scores, offset, rotations, rotation_mapping = ret
                scores = array_to_memmap(scores)
                rotations = array_to_memmap(rotations)
                ret = (scores, offset, rotations, rotation_mapping)

            return ret

        # Determine output array shape and create consistent rotation map
        new_rotation_mapping, out_shape = {}, None
        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            scores, offset, rotations, rotation_mapping = param_stores[i]
            if out_shape is None:
                out_shape = np.zeros(scores.ndim, int)
                scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
            out_shape = np.maximum(out_shape, np.add(offset, scores.shape))

            for key, value in rotation_mapping.items():
                if key not in new_rotation_mapping:
                    new_rotation_mapping[key] = len(new_rotation_mapping)

        if out_shape is None:
            return None

        out_shape = tuple(int(x) for x in out_shape)
        if use_memmap:
            scores_out_filename = generate_tempfile_name()
            rotations_out_filename = generate_tempfile_name()

            scores_out = np.memmap(
                scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
            )
            scores_out.fill(kwargs.get("score_threshold", 0))
            scores_out.flush()
            rotations_out = np.memmap(
                rotations_out_filename,
                mode="w+",
                shape=out_shape,
                dtype=rotations_dtype,
            )
            rotations_out.fill(-1)
            rotations_out.flush()
        else:
            scores_out = np.full(
                out_shape,
                fill_value=kwargs.get("score_threshold", 0),
                dtype=scores_dtype,
            )
            rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)

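        # Write each partial result into its offset window of the merged arrays,
        # keeping the larger score per translation.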
        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            if use_memmap:
                scores_out = np.memmap(
                    scores_out_filename,
                    mode="r+",
                    shape=out_shape,
                    dtype=scores_dtype,
                )
                rotations_out = np.memmap(
                    rotations_out_filename,
                    mode="r+",
                    shape=out_shape,
                    dtype=rotations_dtype,
                )
            scores, offset, rotations, rotation_mapping = param_stores[i]
            stops = np.add(offset, scores.shape).astype(int)
            indices = tuple(slice(*pos) for pos in zip(offset, stops))

            indices_update = scores > scores_out[indices]
            scores_out[indices][indices_update] = scores[indices_update]

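            # Translate this instance's rotation indices into the merged index space.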
            lookup_table = np.arange(
                len(rotation_mapping) + 1, dtype=rotations_out.dtype
            )
            for key, value in rotation_mapping.items():
                lookup_table[value] = new_rotation_mapping[key]

            updated_rotations = rotations[indices_update]
            if len(updated_rotations):
                rotations_out[indices][indices_update] = lookup_table[updated_rotations]

            if use_memmap:
                scores._mmap.close()
                rotations._mmap.close()
                scores_out.flush()
                rotations_out.flush()
                scores_out, rotations_out = None, None

            param_stores[i] = None
            scores, rotations = None, None

        if use_memmap:
            scores_out = np.memmap(
                scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
            )
            rotations_out = np.memmap(
                rotations_out_filename,
                mode="r",
                shape=out_shape,
                dtype=rotations_dtype,
            )

        return (
            scores_out,
            np.zeros(scores_out.ndim, dtype=int),
            rotations_out,
            new_rotation_mapping,
        )

    @property
    def is_shareable(self) -> bool:
        """Boolean indicating whether class instance can be shared across processes."""
        return True

class MaxScoreOverTranslations(MaxScoreOverRotations):
    """
    Determine the translation maximizing the score over all possible rotations.

    Parameters
    ----------
    shape : tuple of int
        Shape of array passed to :py:meth:`MaxScoreOverTranslations.__call__`.
    n_rotations : int
        Number of rotations to aggregate over.
    aggregate_axis : tuple of int, optional
        Array axis to aggregate over, None by default.
    shm_handler : :class:`multiprocessing.managers.SharedMemoryManager`, optional
        Shared memory manager, defaults to memory not being shared.
    **kwargs : dict, optional
        Keyword arguments passed to the constructor of the parent class.
    """

    def __init__(
        self,
        shape: Tuple[int],
        n_rotations: int,
        aggregate_axis: Tuple[int] = None,
        shm_handler: object = None,
        offset: Tuple[int] = None,
        **kwargs: Dict,
    ):
        shape_reduced = [x for i, x in enumerate(shape) if i not in aggregate_axis]
        shape_reduced.insert(0, n_rotations)

        if offset is None:
            offset = be.zeros(len(shape), be._int_dtype)
        offset = [x for i, x in enumerate(offset) if i not in aggregate_axis]
        offset.insert(0, 0)

        super().__init__(
            shape=shape_reduced, shm_handler=shm_handler, offset=offset, **kwargs
        )

        self.rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
        self.rotations = be.to_sharedarr(self.rotations, shm_handler)
        self._aggregate_axis = aggregate_axis

    def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
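        # Store the maximum over the aggregation axes in this rotation's row of scores.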
        if self._lock_is_nullcontext:
            rotation_index = len(self.rotation_mapping)
            if self._inversion_mapping:
                self.rotation_mapping[rotation_index] = rotation_matrix
            else:
                rotation = be.tobytes(rotation_matrix)
                rotation_index = self.rotation_mapping.setdefault(
                    rotation, rotation_index
                )
            max_score = be.max(scores, axis=self._aggregate_axis)
            self.scores[rotation_index] = max_score
            return None

        rotation = be.tobytes(rotation_matrix)
        with self._lock:
            rotation_index = self.rotation_mapping.setdefault(
                rotation, len(self.rotation_mapping)
            )
            internal_scores = be.from_sharedarr(self.scores)
            max_score = be.max(scores, axis=self._aggregate_axis)
            internal_scores[rotation_index] = max_score
        return None

    @classmethod
    def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple:
        """
        Merge multiple instances of the current class.

        Parameters
        ----------
        param_stores : list of tuple
            List of instance's internal state created by applying `tuple(instance)`.
        **kwargs : dict, optional
            Optional keyword arguments.

        Returns
        -------
        NDArray
            Maximum score of each rotation over all observed translations.
        NDArray
            Translation offset, zero by default.
        NDArray
            Mapping between translations and rotation indices.
        Dict
            Mapping between rotations and rotation indices.
        """
        if len(param_stores) == 1:
            return param_stores[0]

        # Determine output array shape and create consistent rotation map
        new_rotation_mapping, out_shape = {}, None
        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            scores, offset, rotations, rotation_mapping = param_stores[i]
            if out_shape is None:
                out_shape = np.zeros(scores.ndim, int)
                scores_dtype, rotations_out = scores.dtype, rotations
            out_shape = np.maximum(out_shape, np.add(offset, scores.shape))

            for key, value in rotation_mapping.items():
                if key not in new_rotation_mapping:
                    new_rotation_mapping[key] = len(new_rotation_mapping)

        if out_shape is None:
            return None

        out_shape[0] = len(new_rotation_mapping)
        out_shape = tuple(int(x) for x in out_shape)

        use_memmap = kwargs.get("use_memmap", False)
        if use_memmap:
            scores_out_filename = generate_tempfile_name()
            scores_out = np.memmap(
                scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
            )
            scores_out.fill(kwargs.get("score_threshold", 0))
            scores_out.flush()
        else:
            scores_out = np.full(
                out_shape,
                fill_value=kwargs.get("score_threshold", 0),
                dtype=scores_dtype,
            )

        for i in range(len(param_stores)):
            if param_stores[i] is None:
                continue

            if use_memmap:
                scores_out = np.memmap(
                    scores_out_filename,
                    mode="r+",
                    shape=out_shape,
                    dtype=scores_dtype,
                )
            scores, offset, rotations, rotation_mapping = param_stores[i]

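            # Map this instance's local rotation indices onto rows of the merged array.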
            outer_table = np.arange(len(rotation_mapping), dtype=int)
            lookup_table = np.array(
                [new_rotation_mapping[key] for key in rotation_mapping.keys()],
                dtype=int,
            )

            stops = np.add(offset, scores.shape).astype(int)
            indices = [slice(*pos) for pos in zip(offset[1:], stops[1:])]
            indices.insert(0, lookup_table)
            indices = tuple(indices)

            scores_out[indices] = np.maximum(scores_out[indices], scores[outer_table])

            if use_memmap:
                scores._mmap.close()
                scores_out.flush()
                scores_out = None

            param_stores[i], scores = None, None

        if use_memmap:
            scores_out = np.memmap(
                scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
            )

        return (
            scores_out,
            np.zeros(scores_out.ndim, dtype=int),
            rotations_out,
            new_rotation_mapping,
        )

    def _postprocess(self, **kwargs):
        return self