pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
- scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +224 -223
- scripts/postprocess.py +283 -163
- scripts/preprocess.py +11 -8
- scripts/preprocessor_gui.py +10 -9
- scripts/refine_matches.py +626 -0
- tests/preprocessing/test_frequency_filters.py +9 -4
- tests/test_analyzer.py +143 -138
- tests/test_matching_cli.py +85 -29
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +25 -17
- tme/analyzer/aggregation.py +385 -220
- tme/analyzer/base.py +138 -0
- tme/analyzer/peaks.py +150 -88
- tme/analyzer/proxy.py +122 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +4 -3
- tme/backends/cupy_backend.py +4 -13
- tme/backends/jax_backend.py +6 -8
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +7 -5
- tme/backends/pytorch_backend.py +14 -4
- tme/cli.py +126 -0
- tme/density.py +4 -3
- tme/filters/__init__.py +1 -1
- tme/filters/_utils.py +4 -3
- tme/filters/bandpass.py +6 -4
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +426 -214
- tme/filters/reconstruction.py +58 -28
- tme/filters/wedge.py +139 -61
- tme/filters/whitening.py +36 -36
- tme/matching_data.py +4 -3
- tme/matching_exhaustive.py +17 -16
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +4 -3
- tme/matching_utils.py +6 -4
- tme/memory.py +4 -3
- tme/orientations.py +9 -6
- tme/parser.py +5 -4
- tme/preprocessor.py +4 -3
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py
ADDED
@@ -0,0 +1,138 @@
|
|
1
|
+
"""
|
2
|
+
Implements abstract base class for template matching analyzers.
|
3
|
+
|
4
|
+
Copyright (c) 2025 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
|
9
|
+
from typing import Tuple, List
|
10
|
+
from abc import ABC, abstractmethod
|
11
|
+
|
12
|
+
__all__ = ["AbstractAnalyzer"]
|
13
|
+
|
14
|
+
|
15
|
+
class AbstractAnalyzer(ABC):
    """
    Common interface of template matching analyzers.

    An analyzer creates a state with :py:meth:`init_state`, folds score
    arrays into that state through :py:meth:`__call__`, turns the state
    into an output with :py:meth:`result` and combines several outputs
    with :py:meth:`merge`.
    """

    @property
    def shareable(self):
        """
        Whether the analyzer may be shared between processes.

        Returns
        -------
        bool
            True if the analyzer supports shared memory operations and is
            safe to use from multiple processes, False if it has to stay
            within a single process. This base implementation always
            returns False; subclasses can override it.
        """
        return False

    @abstractmethod
    def init_state(self, *args, **kwargs) -> Tuple:
        """
        Create the initial analyzer state.

        Returns
        -------
        state
            Tuple with the internal data structures of the analyzer. Its
            layout is defined by the concrete implementation.

        Notes
        -----
        The returned state is what gets handed to subsequent invocations
        of __call__, hence it has to hold everything required to
        accumulate analysis results.
        """

    @abstractmethod
    def __call__(self, state, scores, rotation_matrix, **kwargs) -> Tuple:
        """
        Incorporate a new score array into the analyzer state.

        Parameters
        ----------
        state : object
            Analyzer state obtained from init_state() or from an earlier
            invocation of __call__.
        scores : BackendArray
            Scores computed for the current rotation.
        rotation_matrix : BackendArray
            Rotation matrix the scores were generated with.
        **kwargs : dict
            Implementation-specific keyword arguments.

        Returns
        -------
        state
            Analyzer state including the new scoring data.

        Notes
        -----
        Implementations should be purely functional, i.e., leave the
        passed state untouched and return a new state with the update
        applied. Concrete implementations may use a different signature.
        """

    @abstractmethod
    def result(self, state: Tuple, **kwargs) -> Tuple:
        """
        Turn the accumulated state into the final analysis result.

        Parameters
        ----------
        state : tuple
            Final analyzer state holding all accumulated data.
        **kwargs : dict
            Keyword arguments controlling result processing, for example
            postprocessing parameters.

        Returns
        -------
        result
            Final analysis result. The format is implementation-specific,
            but typically consists of processed scores, rotation
            information and metadata.

        Notes
        -----
        This is the conversion step from internal state to the output
        format consumed by the template matching pipeline. Postprocessing
        such as convolution mode correction or coordinate transformations
        may be applied here.
        """

    @classmethod
    @abstractmethod
    def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
        """
        Combine the results of several analyzer instances.

        Parameters
        ----------
        results : list
            Result objects, each created by the result() method of an
            analyzer instance.
        **kwargs : dict
            Keyword arguments configuring the merge.

        Returns
        -------
        merged_result
            A single result object that combines all inputs.

        Notes
        -----
        Merging is what makes parallel processing possible: results from
        different processes or splits are unified into one. Overlapping
        data should be handled appropriately and consistently.
        """
|
tme/analyzer/peaks.py
CHANGED
@@ -1,18 +1,20 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Implements classes to analyze outputs from exhaustive template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from functools import wraps
|
9
|
-
from abc import
|
10
|
-
from typing import Tuple, List, Dict
|
10
|
+
from abc import abstractmethod
|
11
|
+
from typing import Tuple, List, Dict
|
11
12
|
|
12
13
|
import numpy as np
|
13
14
|
from skimage.feature import peak_local_max
|
14
15
|
from skimage.registration._phase_cross_correlation import _upsampled_dft
|
15
16
|
|
17
|
+
from .base import AbstractAnalyzer
|
16
18
|
from ._utils import score_to_cart
|
17
19
|
from ..backends import backend as be
|
18
20
|
from ..matching_utils import split_shape
|
@@ -141,7 +143,7 @@ def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
|
|
141
143
|
yield from _generate_slices_recursive(0, ())
|
142
144
|
|
143
145
|
|
144
|
-
class PeakCaller(
|
146
|
+
class PeakCaller(AbstractAnalyzer):
|
145
147
|
"""
|
146
148
|
Base class for peak calling algorithms.
|
147
149
|
|
@@ -190,16 +192,7 @@ class PeakCaller(ABC):
|
|
190
192
|
if min_boundary_distance < 0:
|
191
193
|
raise ValueError("min_boundary_distance has to be non-negative.")
|
192
194
|
|
193
|
-
|
194
|
-
self.translations = be.full(
|
195
|
-
(num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
196
|
-
)
|
197
|
-
self.rotations = be.full(
|
198
|
-
(num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
|
199
|
-
)
|
200
|
-
self.scores = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
|
201
|
-
self.details = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
|
202
|
-
|
195
|
+
self.shape = shape
|
203
196
|
self.num_peaks = int(num_peaks)
|
204
197
|
self.min_distance = int(min_distance)
|
205
198
|
self.min_boundary_distance = int(min_boundary_distance)
|
@@ -210,31 +203,47 @@ class PeakCaller(ABC):
|
|
210
203
|
|
211
204
|
self.min_score, self.max_score = min_score, max_score
|
212
205
|
|
213
|
-
|
214
|
-
|
215
|
-
self.convolution_mode = kwargs.get("convolution_mode", None)
|
216
|
-
self.targetshape = kwargs.get("targetshape", None)
|
217
|
-
self.templateshape = kwargs.get("templateshape", None)
|
218
|
-
|
219
|
-
def __iter__(self) -> Generator:
|
206
|
+
@abstractmethod
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
    """
    Identify peaks in a score array.

    Subclasses implement the actual peak calling strategy here.

    Parameters
    ----------
    scores : BackendArray
        Score array in which peaks are called.
    **kwargs : dict
        Optional keyword arguments forwarded to the underlying
        implementation.

    Returns
    -------
    BackendArray
        Peak positions (n, d).
    BackendArray
        Peak details (n, d).
    """
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
225
|
+
|
226
|
+
def init_state(self):
    """
    Create the initial peak state.

    Returns
    -------
    tuple
        (translations, rotations, scores, details) where translations is
        a (num_peaks, d) integer array filled with -1, rotations is a
        (num_peaks, d, d) float array holding identity matrices, and
        scores and details are (num_peaks, ) float arrays filled with -1.
        d is the length of ``self.shape``.
    """
    n_dim = len(self.shape)
    n_peaks = self.num_peaks

    positions = be.full((n_peaks, n_dim), fill_value=-1, dtype=be._int_dtype)

    # Start from zeros and set the diagonal so each entry is an identity.
    rotation_stack = be.full(
        (n_peaks, n_dim, n_dim), fill_value=0, dtype=be._float_dtype
    )
    for axis in range(n_dim):
        rotation_stack[:, axis, axis] = 1.0

    peak_scores = be.full((n_peaks,), fill_value=-1, dtype=be._float_dtype)
    peak_details = be.full((n_peaks,), fill_value=-1, dtype=be._float_dtype)
    return positions, rotation_stack, peak_scores, peak_details
|
231
240
|
|
232
241
|
def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
|
233
242
|
if not len(peaks):
|
234
243
|
return None
|
235
244
|
|
236
245
|
valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
|
237
|
-
if self.min_boundary_distance
|
246
|
+
if self.min_boundary_distance >= 0:
|
238
247
|
upper_limit = be.subtract(
|
239
248
|
be.to_backend_array(scores.shape), self.min_boundary_distance
|
240
249
|
)
|
@@ -315,20 +324,34 @@ class PeakCaller(ABC):
|
|
315
324
|
peak_positions = be.astype(peak_positions, int)
|
316
325
|
return peak_positions, peak_details
|
317
326
|
|
318
|
-
def __call__(
|
327
|
+
def __call__(
|
328
|
+
self,
|
329
|
+
state: Tuple,
|
330
|
+
scores: BackendArray,
|
331
|
+
rotation_matrix: BackendArray,
|
332
|
+
**kwargs,
|
333
|
+
) -> Tuple:
|
319
334
|
"""
|
320
335
|
Update the internal parameter store based on input array.
|
321
336
|
|
322
337
|
Parameters
|
323
338
|
----------
|
339
|
+
state : tuple
|
340
|
+
Current state tuple where:
|
341
|
+
- positions : BackendArray, (n, d) of peak positions
|
342
|
+
- rotations : BackendArray, (n, d, d) of corresponding rotations
|
343
|
+
- scores : BackendArray, (n, ) of peak scores
|
344
|
+
- details : BackendArray, (n, ) of peak details
|
324
345
|
scores : BackendArray
|
325
|
-
|
346
|
+
Array of new scores to update analyzer with.
|
326
347
|
rotation_matrix : BackendArray
|
327
348
|
Rotation matrix used to obtain the score array.
|
328
349
|
**kwargs
|
329
350
|
Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
|
330
351
|
"""
|
331
|
-
for ret in self._call_peaks(
|
352
|
+
for ret in self._call_peaks(
|
353
|
+
scores=scores, rotation_matrix=rotation_matrix, **kwargs
|
354
|
+
):
|
332
355
|
peak_positions, peak_details = ret
|
333
356
|
if peak_positions is None:
|
334
357
|
continue
|
@@ -341,7 +364,6 @@ class PeakCaller(ABC):
|
|
341
364
|
peak_scores = scores[tuple(peak_positions.T)]
|
342
365
|
if peak_details is not None:
|
343
366
|
peak_details = peak_details[valid_peaks]
|
344
|
-
# peak_details, peak_scores = peak_scores, -peak_details
|
345
367
|
else:
|
346
368
|
peak_details = be.full(peak_scores.shape, fill_value=-1)
|
347
369
|
|
@@ -351,66 +373,57 @@ class PeakCaller(ABC):
|
|
351
373
|
axis=0,
|
352
374
|
)
|
353
375
|
|
354
|
-
self._update(
|
376
|
+
state = self._update(
|
377
|
+
state,
|
355
378
|
peak_positions=peak_positions,
|
356
379
|
peak_details=peak_details,
|
357
380
|
peak_scores=peak_scores,
|
358
|
-
|
381
|
+
peak_rotations=rotations,
|
359
382
|
)
|
360
383
|
|
361
|
-
return
|
362
|
-
|
363
|
-
@abstractmethod
|
364
|
-
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
|
365
|
-
"""
|
366
|
-
Call peaks in the score space.
|
367
|
-
|
368
|
-
Parameters
|
369
|
-
----------
|
370
|
-
scores : BackendArray
|
371
|
-
Score array.
|
372
|
-
**kwargs : dict
|
373
|
-
Optional keyword arguments passed to underlying implementations.
|
374
|
-
|
375
|
-
Returns
|
376
|
-
-------
|
377
|
-
Tuple[BackendArray, BackendArray]
|
378
|
-
Array of peak coordinates and peak details.
|
379
|
-
"""
|
384
|
+
return state
|
380
385
|
|
381
386
|
@classmethod
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
    """
    Merge multiple instances of :py:class:`PeakCaller`.

    Parameters
    ----------
    results : list of tuple
        List of instance results created by applying `result`. Empty
        tuples are skipped.
    **kwargs
        Optional keyword arguments. They are passed to the class
        constructor. If ``shape`` is missing, a placeholder shape of ones
        with the dimensionality of the first non-empty result is used.
        ``offset`` is additionally forwarded to every state update.

    Returns
    -------
    NDArray
        Peak positions (n, d).
    NDArray
        Peak rotation matrices (n, d, d).
    NDArray
        Peak scores (n, ).
    NDArray
        Peak details (n,).

    Raises
    ------
    ValueError
        If ``shape`` is not given and all results are empty, so the
        dimensionality cannot be inferred.
    """
    if "shape" not in kwargs:
        # Infer d from the first result that actually holds peaks, since
        # empty results are explicitly tolerated in the loop below.
        populated = next((res for res in results if len(res)), None)
        if populated is None:
            raise ValueError(
                "Cannot infer dimensionality from empty results, "
                "pass 'shape' explicitly."
            )
        kwargs["shape"] = tuple(1 for _ in range(populated[0].shape[1]))

    base = cls(**kwargs)
    base_state = base.init_state()
    for result in results:
        if len(result) == 0:
            continue
        peak_positions, rotations, peak_scores, peak_details = result
        base_state = base._update(
            base_state,
            peak_positions=be.to_backend_array(peak_positions),
            peak_details=be.to_backend_array(peak_details),
            peak_scores=be.to_backend_array(peak_scores),
            peak_rotations=be.to_backend_array(rotations),
            offset=kwargs.get("offset", None),
        )
    return base_state
|
414
427
|
|
415
428
|
@staticmethod
|
416
429
|
def oversample_peaks(
|
@@ -517,10 +530,11 @@ class PeakCaller(ABC):
|
|
517
530
|
|
518
531
|
def _update(
|
519
532
|
self,
|
533
|
+
state,
|
520
534
|
peak_positions: BackendArray,
|
521
535
|
peak_details: BackendArray,
|
522
536
|
peak_scores: BackendArray,
|
523
|
-
|
537
|
+
peak_rotations: BackendArray,
|
524
538
|
offset: BackendArray = None,
|
525
539
|
):
|
526
540
|
"""
|
@@ -539,14 +553,15 @@ class PeakCaller(ABC):
|
|
539
553
|
offset : BackendArray, optional
|
540
554
|
Translation offset, e.g. from splitting, (d, ).
|
541
555
|
"""
|
556
|
+
translations, rotations, scores, details = state
|
542
557
|
if offset is not None:
|
543
558
|
offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
|
544
559
|
peak_positions = be.add(peak_positions, offset, out=peak_positions)
|
545
560
|
|
546
|
-
positions = be.concatenate((
|
547
|
-
rotations = be.concatenate((
|
548
|
-
scores = be.concatenate((
|
549
|
-
details = be.concatenate((
|
561
|
+
positions = be.concatenate((translations, peak_positions))
|
562
|
+
rotations = be.concatenate((rotations, peak_rotations))
|
563
|
+
scores = be.concatenate((scores, peak_scores))
|
564
|
+
details = be.concatenate((details, peak_details))
|
550
565
|
|
551
566
|
# topk filtering after distances yields more distributed peak calls
|
552
567
|
distance_order = filter_points_indices(
|
@@ -561,22 +576,69 @@ class PeakCaller(ABC):
|
|
561
576
|
)
|
562
577
|
final_order = distance_order[top_scores]
|
563
578
|
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
579
|
+
translations = positions[final_order, :]
|
580
|
+
rotations = rotations[final_order, :]
|
581
|
+
scores = scores[final_order]
|
582
|
+
details = details[final_order]
|
583
|
+
return translations, rotations, scores, details
|
568
584
|
|
569
|
-
def
|
570
|
-
|
571
|
-
|
585
|
+
def result(
    self,
    state,
    fast_shape: Tuple[int] = None,
    targetshape: Tuple[int] = None,
    templateshape: Tuple[int] = None,
    convolution_shape: Tuple[int] = None,
    fourier_shift: Tuple[int] = None,
    convolution_mode: str = None,
    **kwargs,
):
    """
    Convert the accumulated peak state into the final output.

    Parameters
    ----------
    state : tuple
        Current state tuple where:
        - positions : BackendArray, (n, d) of peak positions
        - rotations : BackendArray, (n, d, d) of corresponding rotations
        - scores : BackendArray, (n, ) of peak scores
        - details : BackendArray, (n, ) of peak details
    fast_shape : Tuple[int], optional
        Forwarded unchanged to ``score_to_cart``.
    targetshape : Tuple[int], optional
        Shape of the target for convolution mode correction.
    templateshape : Tuple[int], optional
        Shape of the template for convolution mode correction.
    convolution_shape : Tuple[int], optional
        Shape used for convolution.
    fourier_shift : Tuple[int], optional.
        Shift to apply for Fourier correction.
    convolution_mode : str, optional
        Convolution mode for padding correction.
    **kwargs
        Additional keyword arguments, forwarded to ``score_to_cart``.

    Returns
    -------
    tuple
        Final result tuple (positions, rotations, scores, details),
        restricted to the peaks ``score_to_cart`` reports as valid and
        converted with ``be.to_cpu_array``.
    """
    peak_translations, peak_rotations, peak_scores, peak_details = state

    cart_positions, keep = score_to_cart(
        positions=peak_translations,
        fast_shape=fast_shape,
        targetshape=targetshape,
        templateshape=templateshape,
        convolution_shape=convolution_shape,
        fourier_shift=fourier_shift,
        convolution_mode=convolution_mode,
        **kwargs,
    )

    # Apply the same validity selection to every component of the state.
    components = (cart_positions, peak_rotations, peak_scores, peak_details)
    return tuple(be.to_cpu_array(component[keep]) for component in components)
|
580
642
|
|
581
643
|
|
582
644
|
class PeakCallerSort(PeakCaller):
|
tme/analyzer/proxy.py
ADDED
@@ -0,0 +1,122 @@
|
|
1
|
+
"""
|
2
|
+
Implements SharedAnalyzerProxy to manage shared memory of Analyzer instances
|
3
|
+
across different tasks.
|
4
|
+
|
5
|
+
This is primarily useful for CPU template matching, where parallelization can
|
6
|
+
be performed over rotations, rather than subsections of a large input volume.
|
7
|
+
|
8
|
+
Copyright (c) 2025 European Molecular Biology Laboratory
|
9
|
+
|
10
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
11
|
+
"""
|
12
|
+
|
13
|
+
from typing import Tuple
|
14
|
+
from multiprocessing import Manager
|
15
|
+
from multiprocessing.shared_memory import SharedMemory
|
16
|
+
|
17
|
+
from ..backends import backend as be
|
18
|
+
|
19
|
+
__all__ = ["StatelessSharedAnalyzerProxy", "SharedAnalyzerProxy"]
|
20
|
+
|
21
|
+
class StatelessSharedAnalyzerProxy:
    """
    Proxy that wraps functional analyzers for concurrent access via shared memory.

    Enables multiple processes/threads to safely update the same analyzer
    while preserving the functional interface of the underlying analyzer.

    Parameters
    ----------
    analyzer_class : type
        Analyzer class to instantiate.
    analyzer_params : dict
        Keyword arguments passed to ``analyzer_class``.
    """

    def __init__(self, analyzer_class: type, analyzer_params: dict):
        # Until init_state receives a shared memory handler, calls go
        # straight to the analyzer without locking.
        self._shared = False
        self._process = self._direct_call

        self._analyzer = analyzer_class(**analyzer_params)

    def __call__(self, state, *args, **kwargs):
        """Update ``state`` via the wrapped analyzer and return its output."""
        return self._process(state, *args, **kwargs)

    def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
        """
        Create the analyzer state, optionally placed in shared memory.

        Parameters
        ----------
        shm_handler : object, optional
            Shared memory handler. If given, the state is converted to its
            shared representation and subsequent calls are serialized
            through a manager lock.
        *args, **kwargs
            Accepted for interface compatibility and ignored.

        Returns
        -------
        tuple
            Initial state, shared if ``shm_handler`` was provided.
        """
        state = self._analyzer.init_state()
        if shm_handler is not None:
            self._shared = True
            state = self._to_shared(state, shm_handler)

            self._lock = Manager().Lock()
            self._process = self._thread_safe_call
        return state

    def _to_shared(self, state: Tuple, shm_handler):
        """
        Convert backend arrays and dicts in ``state`` to shared objects.

        Other element types are passed through unchanged.
        """
        backend_arr = type(be.zeros((1), dtype=be._float_dtype))

        # Created lazily and reused so at most one manager process is
        # started for all dicts in the state.
        manager = None
        ret = []
        for v in state:
            if isinstance(v, backend_arr):
                v = be.to_sharedarr(v, shm_handler)
            elif isinstance(v, dict):
                if manager is None:
                    manager = Manager()
                # Pass the mapping positionally: unpacking with ** only
                # works for string keys.
                v = manager.dict(v)
            ret.append(v)
        return tuple(ret)

    def _shared_to_object(self, shared: type):
        """
        Map a shared array handle back to a backend array.

        Non-empty tuples whose first element is a ``SharedMemory`` instance
        are converted with ``be.from_sharedarr``; anything else is returned
        unchanged, as is every input while the proxy is not shared.
        """
        if not self._shared:
            return shared
        if isinstance(shared, tuple) and len(shared):
            if isinstance(shared[0], SharedMemory):
                return be.from_sharedarr(shared)
        return shared

    def _thread_safe_call(self, state, *args, **kwargs):
        """Thread-safe call to analyzer"""
        with self._lock:
            state = tuple(self._shared_to_object(x) for x in state)
            return self._direct_call(state, *args, **kwargs)

    def _direct_call(self, state, *args, **kwargs):
        """Direct call to analyzer without locking"""
        return self._analyzer(state, *args, **kwargs)

    def result(self, state, **kwargs):
        """
        Extract the final result from ``state``.

        Keyword arguments are forwarded to the ``result`` method of the
        wrapped analyzer.
        """
        final_state = state
        if self._shared:
            # Convert shared arrays back to regular arrays.
            # NOTE(review): no explicit copy is made here; whether
            # be.from_sharedarr copies is not visible from this file.
            final_state = tuple(self._shared_to_object(x) for x in final_state)
        return self._analyzer.result(final_state, **kwargs)

    def merge(self, *args, **kwargs):
        """Delegate to the ``merge`` method of the wrapped analyzer."""
        return self._analyzer.merge(*args, **kwargs)
|
89
|
+
|
90
|
+
|
91
|
+
class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
    """
    Child of :py:class:`StatelessSharedAnalyzerProxy` that is aware
    of the current analyzer state to emulate the previous analyzer interface.

    Parameters
    ----------
    analyzer_class : type
        Analyzer class to instantiate.
    analyzer_params : dict
        Keyword arguments passed to ``analyzer_class``.
    shm_handler : object, optional
        Shared memory handler. Ignored if the analyzer reports that it is
        not shareable.
    **kwargs
        Accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        analyzer_class: type,
        analyzer_params: dict,
        shm_handler: type = None,
        **kwargs,
    ):
        super().__init__(
            analyzer_class=analyzer_class,
            analyzer_params=analyzer_params,
        )
        # Fall back to process-local state for non-shareable analyzers.
        if not self._analyzer.shareable:
            shm_handler = None
        self.init_state(shm_handler)

    def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
        """
        Create the analyzer state and store it on the proxy.

        Returns
        -------
        tuple
            The newly created state, also available as ``self._state``.
        """
        self._state = super().init_state(shm_handler, *args, **kwargs)
        return self._state

    def __call__(self, *args, **kwargs):
        """
        Update the stored state with the given analyzer arguments.

        The state returned by the analyzer replaces the stored one only
        when the proxy is not shared.
        """
        state = super().__call__(self._state, *args, **kwargs)
        # NOTE(review): in shared mode the returned state is discarded,
        # presumably because the analyzer writes into the shared buffers
        # in place; confirm against the analyzer implementations.
        if not self._shared:
            self._state = state

    def result(self, **kwargs):
        """Extract final result"""
        return super().result(self._state, **kwargs)
|
tme/backends/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
pyTME backend manager.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from typing import Dict, List
|