pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
- {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
- {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
- {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
- {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/METADATA +10 -9
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
- scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +224 -223
- scripts/postprocess.py +283 -163
- scripts/preprocess.py +11 -8
- scripts/preprocessor_gui.py +10 -9
- scripts/refine_matches.py +626 -0
- tests/preprocessing/test_frequency_filters.py +9 -4
- tests/test_analyzer.py +143 -138
- tests/test_matching_cli.py +85 -30
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +25 -17
- tme/analyzer/aggregation.py +384 -220
- tme/analyzer/base.py +138 -0
- tme/analyzer/peaks.py +150 -91
- tme/analyzer/proxy.py +122 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +4 -3
- tme/backends/cupy_backend.py +4 -13
- tme/backends/jax_backend.py +6 -8
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +7 -5
- tme/backends/pytorch_backend.py +14 -4
- tme/cli.py +126 -0
- tme/density.py +4 -3
- tme/filters/__init__.py +1 -1
- tme/filters/_utils.py +4 -3
- tme/filters/bandpass.py +6 -4
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +426 -214
- tme/filters/reconstruction.py +58 -28
- tme/filters/wedge.py +139 -61
- tme/filters/whitening.py +36 -36
- tme/matching_data.py +4 -3
- tme/matching_exhaustive.py +17 -16
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +4 -3
- tme/matching_utils.py +41 -3
- tme/memory.py +4 -3
- tme/orientations.py +9 -6
- tme/parser.py +5 -4
- tme/preprocessor.py +4 -3
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py
ADDED
@@ -0,0 +1,138 @@
|
|
1
|
+
"""
|
2
|
+
Implements abstract base class for template matching analyzers.
|
3
|
+
|
4
|
+
Copyright (c) 2025 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
|
9
|
+
from typing import Tuple, List
|
10
|
+
from abc import ABC, abstractmethod
|
11
|
+
|
12
|
+
__all__ = ["AbstractAnalyzer"]
|
13
|
+
|
14
|
+
|
15
|
+
class AbstractAnalyzer(ABC):
|
16
|
+
"""
|
17
|
+
Abstract base class for template matching analyzers.
|
18
|
+
"""
|
19
|
+
|
20
|
+
@property
|
21
|
+
def shareable(self):
|
22
|
+
"""
|
23
|
+
Indicate whether the analyzer can be shared across processes.
|
24
|
+
|
25
|
+
Returns
|
26
|
+
-------
|
27
|
+
bool
|
28
|
+
True if the analyzer supports shared memory operations
|
29
|
+
and can be safely used across multiple processes, False
|
30
|
+
if it should only be used within a single process.
|
31
|
+
"""
|
32
|
+
return False
|
33
|
+
|
34
|
+
@abstractmethod
|
35
|
+
def init_state(self, *args, **kwargs) -> Tuple:
|
36
|
+
"""
|
37
|
+
Initialize the analyzer state.
|
38
|
+
|
39
|
+
Returns
|
40
|
+
-------
|
41
|
+
state
|
42
|
+
Initial state tuple containing the analyzer's internal data
|
43
|
+
structures. The exact structure depends on the specific
|
44
|
+
implementation.
|
45
|
+
|
46
|
+
Notes
|
47
|
+
-----
|
48
|
+
This method creates the initial state that will be passed to
|
49
|
+
subsequent calls to __call__. The state should contain all
|
50
|
+
necessary data structures for accumulating analysis results.
|
51
|
+
"""
|
52
|
+
|
53
|
+
@abstractmethod
|
54
|
+
def __call__(self, state, scores, rotation_matrix, **kwargs) -> Tuple:
|
55
|
+
"""
|
56
|
+
Update the analyzer state with new scoring data.
|
57
|
+
|
58
|
+
Parameters
|
59
|
+
----------
|
60
|
+
state : object
|
61
|
+
Current analyzer state as returned by init_state() or
|
62
|
+
previous calls to __call__.
|
63
|
+
scores : BackendArray
|
64
|
+
Array of scores computed for the current rotation.
|
65
|
+
rotation_matrix : BackendArray
|
66
|
+
Rotation matrix used to generate the scores.
|
67
|
+
**kwargs : dict
|
68
|
+
Additional keyword arguments specific to the analyzer
|
69
|
+
implementation.
|
70
|
+
|
71
|
+
Returns
|
72
|
+
-------
|
73
|
+
state
|
74
|
+
Updated analyzer state with the new scoring data incorporated.
|
75
|
+
|
76
|
+
Notes
|
77
|
+
-----
|
78
|
+
This method should be purely functional - it should not modify
|
79
|
+
the input state but return a new state with the updates applied.
|
80
|
+
The exact signature may vary between implementations.
|
81
|
+
"""
|
82
|
+
pass
|
83
|
+
|
84
|
+
@abstractmethod
|
85
|
+
def result(self, state: Tuple, **kwargs) -> Tuple:
|
86
|
+
"""
|
87
|
+
Finalize the analysis and produce the final result.
|
88
|
+
|
89
|
+
Parameters
|
90
|
+
----------
|
91
|
+
state : tuple
|
92
|
+
Final analyzer state containing all accumulated data.
|
93
|
+
**kwargs : dict
|
94
|
+
Additional keyword arguments for result processing,
|
95
|
+
such as postprocessing parameters.
|
96
|
+
|
97
|
+
Returns
|
98
|
+
-------
|
99
|
+
result
|
100
|
+
Final analysis result. The exact format depends on the
|
101
|
+
analyzer implementation but typically includes processed
|
102
|
+
scores, rotation information, and metadata.
|
103
|
+
|
104
|
+
Notes
|
105
|
+
-----
|
106
|
+
This method converts the internal analyzer state into the
|
107
|
+
final output format expected by the template matching pipeline.
|
108
|
+
It may apply postprocessing operations like convolution mode
|
109
|
+
correction or coordinate transformations.
|
110
|
+
"""
|
111
|
+
pass
|
112
|
+
|
113
|
+
@classmethod
|
114
|
+
@abstractmethod
|
115
|
+
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
|
116
|
+
"""
|
117
|
+
Merge results from multiple analyzer instances.
|
118
|
+
|
119
|
+
Parameters
|
120
|
+
----------
|
121
|
+
results : list
|
122
|
+
List of result objects as returned by the result() method
|
123
|
+
from multiple analyzer instances.
|
124
|
+
**kwargs : dict
|
125
|
+
Additional keyword arguments for merge configuration.
|
126
|
+
|
127
|
+
Returns
|
128
|
+
-------
|
129
|
+
merged_result
|
130
|
+
Single result object combining all input results.
|
131
|
+
|
132
|
+
Notes
|
133
|
+
-----
|
134
|
+
This method enables parallel processing by allowing results
|
135
|
+
from different processes or splits to be combined into a
|
136
|
+
unified result. The merge operation should handle overlapping
|
137
|
+
data appropriately and maintain consistency.
|
138
|
+
"""
|
tme/analyzer/peaks.py
CHANGED
@@ -1,18 +1,20 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Implements classes to analyze outputs from exhaustive template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from functools import wraps
|
9
|
-
from abc import
|
10
|
-
from typing import Tuple, List, Dict
|
10
|
+
from abc import abstractmethod
|
11
|
+
from typing import Tuple, List, Dict
|
11
12
|
|
12
13
|
import numpy as np
|
13
14
|
from skimage.feature import peak_local_max
|
14
15
|
from skimage.registration._phase_cross_correlation import _upsampled_dft
|
15
16
|
|
17
|
+
from .base import AbstractAnalyzer
|
16
18
|
from ._utils import score_to_cart
|
17
19
|
from ..backends import backend as be
|
18
20
|
from ..matching_utils import split_shape
|
@@ -141,7 +143,7 @@ def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
|
|
141
143
|
yield from _generate_slices_recursive(0, ())
|
142
144
|
|
143
145
|
|
144
|
-
class PeakCaller(
|
146
|
+
class PeakCaller(AbstractAnalyzer):
|
145
147
|
"""
|
146
148
|
Base class for peak calling algorithms.
|
147
149
|
|
@@ -190,19 +192,7 @@ class PeakCaller(ABC):
|
|
190
192
|
if min_boundary_distance < 0:
|
191
193
|
raise ValueError("min_boundary_distance has to be non-negative.")
|
192
194
|
|
193
|
-
|
194
|
-
self.translations = be.full(
|
195
|
-
(num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
196
|
-
)
|
197
|
-
self.rotations = be.full(
|
198
|
-
(num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
|
199
|
-
)
|
200
|
-
for i in range(ndim):
|
201
|
-
self.rotations[:, i, i] = 1.0
|
202
|
-
|
203
|
-
self.scores = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
|
204
|
-
self.details = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
|
205
|
-
|
195
|
+
self.shape = shape
|
206
196
|
self.num_peaks = int(num_peaks)
|
207
197
|
self.min_distance = int(min_distance)
|
208
198
|
self.min_boundary_distance = int(min_boundary_distance)
|
@@ -213,31 +203,47 @@ class PeakCaller(ABC):
|
|
213
203
|
|
214
204
|
self.min_score, self.max_score = min_score, max_score
|
215
205
|
|
216
|
-
|
217
|
-
|
218
|
-
self.convolution_mode = kwargs.get("convolution_mode", None)
|
219
|
-
self.targetshape = kwargs.get("targetshape", None)
|
220
|
-
self.templateshape = kwargs.get("templateshape", None)
|
221
|
-
|
222
|
-
def __iter__(self) -> Generator:
|
206
|
+
@abstractmethod
|
207
|
+
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
|
223
208
|
"""
|
224
|
-
|
225
|
-
|
209
|
+
Call peaks in the score space.
|
210
|
+
|
211
|
+
Parameters
|
212
|
+
----------
|
213
|
+
scores : BackendArray
|
214
|
+
Score array to update analyzer with.
|
215
|
+
**kwargs : dict
|
216
|
+
Optional keyword arguments passed to underlying implementations.
|
217
|
+
|
218
|
+
Returns
|
219
|
+
-------
|
220
|
+
BackendArray
|
221
|
+
Peak positions (n, d).
|
222
|
+
BackendArray
|
223
|
+
Peak details (n, d).
|
226
224
|
"""
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
225
|
+
|
226
|
+
def init_state(self):
|
227
|
+
ndim = len(self.shape)
|
228
|
+
translations = be.full(
|
229
|
+
(self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
230
|
+
)
|
231
|
+
rotations = be.full(
|
232
|
+
(self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
|
233
|
+
)
|
234
|
+
for i in range(ndim):
|
235
|
+
rotations[:, i, i] = 1.0
|
236
|
+
|
237
|
+
scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
|
238
|
+
details = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
|
239
|
+
return translations, rotations, scores, details
|
234
240
|
|
235
241
|
def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
|
236
242
|
if not len(peaks):
|
237
243
|
return None
|
238
244
|
|
239
245
|
valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
|
240
|
-
if self.min_boundary_distance
|
246
|
+
if self.min_boundary_distance >= 0:
|
241
247
|
upper_limit = be.subtract(
|
242
248
|
be.to_backend_array(scores.shape), self.min_boundary_distance
|
243
249
|
)
|
@@ -318,20 +324,34 @@ class PeakCaller(ABC):
|
|
318
324
|
peak_positions = be.astype(peak_positions, int)
|
319
325
|
return peak_positions, peak_details
|
320
326
|
|
321
|
-
def __call__(
|
327
|
+
def __call__(
|
328
|
+
self,
|
329
|
+
state: Tuple,
|
330
|
+
scores: BackendArray,
|
331
|
+
rotation_matrix: BackendArray,
|
332
|
+
**kwargs,
|
333
|
+
) -> Tuple:
|
322
334
|
"""
|
323
335
|
Update the internal parameter store based on input array.
|
324
336
|
|
325
337
|
Parameters
|
326
338
|
----------
|
339
|
+
state : tuple
|
340
|
+
Current state tuple where:
|
341
|
+
- positions : BackendArray, (n, d) of peak positions
|
342
|
+
- rotations : BackendArray, (n, d, d) of corresponding rotations
|
343
|
+
- scores : BackendArray, (n, ) of peak scores
|
344
|
+
- details : BackendArray, (n, ) of peak details
|
327
345
|
scores : BackendArray
|
328
|
-
|
346
|
+
Array of new scores to update analyzer with.
|
329
347
|
rotation_matrix : BackendArray
|
330
348
|
Rotation matrix used to obtain the score array.
|
331
349
|
**kwargs
|
332
350
|
Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
|
333
351
|
"""
|
334
|
-
for ret in self._call_peaks(
|
352
|
+
for ret in self._call_peaks(
|
353
|
+
scores=scores, rotation_matrix=rotation_matrix, **kwargs
|
354
|
+
):
|
335
355
|
peak_positions, peak_details = ret
|
336
356
|
if peak_positions is None:
|
337
357
|
continue
|
@@ -344,7 +364,6 @@ class PeakCaller(ABC):
|
|
344
364
|
peak_scores = scores[tuple(peak_positions.T)]
|
345
365
|
if peak_details is not None:
|
346
366
|
peak_details = peak_details[valid_peaks]
|
347
|
-
# peak_details, peak_scores = peak_scores, -peak_details
|
348
367
|
else:
|
349
368
|
peak_details = be.full(peak_scores.shape, fill_value=-1)
|
350
369
|
|
@@ -354,66 +373,57 @@ class PeakCaller(ABC):
|
|
354
373
|
axis=0,
|
355
374
|
)
|
356
375
|
|
357
|
-
self._update(
|
376
|
+
state = self._update(
|
377
|
+
state,
|
358
378
|
peak_positions=peak_positions,
|
359
379
|
peak_details=peak_details,
|
360
380
|
peak_scores=peak_scores,
|
361
|
-
|
381
|
+
peak_rotations=rotations,
|
362
382
|
)
|
363
383
|
|
364
|
-
return
|
365
|
-
|
366
|
-
@abstractmethod
|
367
|
-
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
|
368
|
-
"""
|
369
|
-
Call peaks in the score space.
|
370
|
-
|
371
|
-
Parameters
|
372
|
-
----------
|
373
|
-
scores : BackendArray
|
374
|
-
Score array.
|
375
|
-
**kwargs : dict
|
376
|
-
Optional keyword arguments passed to underlying implementations.
|
377
|
-
|
378
|
-
Returns
|
379
|
-
-------
|
380
|
-
Tuple[BackendArray, BackendArray]
|
381
|
-
Array of peak coordinates and peak details.
|
382
|
-
"""
|
384
|
+
return state
|
383
385
|
|
384
386
|
@classmethod
|
385
|
-
def merge(cls,
|
387
|
+
def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
|
386
388
|
"""
|
387
389
|
Merge multiple instances of :py:class:`PeakCaller`.
|
388
390
|
|
389
391
|
Parameters
|
390
392
|
----------
|
391
|
-
|
392
|
-
|
393
|
+
results : list of tuple
|
394
|
+
List of instance results created by applying `result`.
|
393
395
|
**kwargs
|
394
396
|
Optional keyword arguments.
|
395
397
|
|
396
398
|
Returns
|
397
399
|
-------
|
398
|
-
|
399
|
-
|
400
|
+
NDArray
|
401
|
+
Peak positions (n, d).
|
402
|
+
NDArray
|
403
|
+
Peak rotation matrices (n, d, d).
|
404
|
+
NDArray
|
405
|
+
Peak scores (n, ).
|
406
|
+
NDArray
|
407
|
+
Peak details (n,).
|
400
408
|
"""
|
401
409
|
if "shape" not in kwargs:
|
402
|
-
kwargs["shape"] = tuple(1 for _ in range(
|
410
|
+
kwargs["shape"] = tuple(1 for _ in range(results[0][0].shape[1]))
|
403
411
|
|
404
412
|
base = cls(**kwargs)
|
405
|
-
|
406
|
-
|
413
|
+
base_state = base.init_state()
|
414
|
+
for result in results:
|
415
|
+
if len(result) == 0:
|
407
416
|
continue
|
408
|
-
peak_positions, rotations, peak_scores, peak_details =
|
409
|
-
base._update(
|
417
|
+
peak_positions, rotations, peak_scores, peak_details = result
|
418
|
+
base_state = base._update(
|
419
|
+
base_state,
|
410
420
|
peak_positions=be.to_backend_array(peak_positions),
|
411
421
|
peak_details=be.to_backend_array(peak_details),
|
412
422
|
peak_scores=be.to_backend_array(peak_scores),
|
413
|
-
|
423
|
+
peak_rotations=be.to_backend_array(rotations),
|
414
424
|
offset=kwargs.get("offset", None),
|
415
425
|
)
|
416
|
-
return
|
426
|
+
return base_state
|
417
427
|
|
418
428
|
@staticmethod
|
419
429
|
def oversample_peaks(
|
@@ -520,10 +530,11 @@ class PeakCaller(ABC):
|
|
520
530
|
|
521
531
|
def _update(
|
522
532
|
self,
|
533
|
+
state,
|
523
534
|
peak_positions: BackendArray,
|
524
535
|
peak_details: BackendArray,
|
525
536
|
peak_scores: BackendArray,
|
526
|
-
|
537
|
+
peak_rotations: BackendArray,
|
527
538
|
offset: BackendArray = None,
|
528
539
|
):
|
529
540
|
"""
|
@@ -542,14 +553,15 @@ class PeakCaller(ABC):
|
|
542
553
|
offset : BackendArray, optional
|
543
554
|
Translation offset, e.g. from splitting, (d, ).
|
544
555
|
"""
|
556
|
+
translations, rotations, scores, details = state
|
545
557
|
if offset is not None:
|
546
558
|
offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
|
547
559
|
peak_positions = be.add(peak_positions, offset, out=peak_positions)
|
548
560
|
|
549
|
-
positions = be.concatenate((
|
550
|
-
rotations = be.concatenate((
|
551
|
-
scores = be.concatenate((
|
552
|
-
details = be.concatenate((
|
561
|
+
positions = be.concatenate((translations, peak_positions))
|
562
|
+
rotations = be.concatenate((rotations, peak_rotations))
|
563
|
+
scores = be.concatenate((scores, peak_scores))
|
564
|
+
details = be.concatenate((details, peak_details))
|
553
565
|
|
554
566
|
# topk filtering after distances yields more distributed peak calls
|
555
567
|
distance_order = filter_points_indices(
|
@@ -564,22 +576,69 @@ class PeakCaller(ABC):
|
|
564
576
|
)
|
565
577
|
final_order = distance_order[top_scores]
|
566
578
|
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
579
|
+
translations = positions[final_order, :]
|
580
|
+
rotations = rotations[final_order, :]
|
581
|
+
scores = scores[final_order]
|
582
|
+
details = details[final_order]
|
583
|
+
return translations, rotations, scores, details
|
571
584
|
|
572
|
-
def
|
573
|
-
|
574
|
-
|
585
|
+
def result(
|
586
|
+
self,
|
587
|
+
state,
|
588
|
+
fast_shape: Tuple[int] = None,
|
589
|
+
targetshape: Tuple[int] = None,
|
590
|
+
templateshape: Tuple[int] = None,
|
591
|
+
convolution_shape: Tuple[int] = None,
|
592
|
+
fourier_shift: Tuple[int] = None,
|
593
|
+
convolution_mode: str = None,
|
594
|
+
**kwargs,
|
595
|
+
):
|
596
|
+
"""
|
597
|
+
Finalize the analysis result with optional postprocessing.
|
575
598
|
|
576
|
-
|
599
|
+
Parameters
|
600
|
+
----------
|
601
|
+
state : tuple
|
602
|
+
Current state tuple where:
|
603
|
+
- positions : BackendArray, (n, d) of peak positions
|
604
|
+
- rotations : BackendArray, (n, d, d) of corresponding rotations
|
605
|
+
- scores : BackendArray, (n, ) of peak scores
|
606
|
+
- details : BackendArray, (n, ) of peak details
|
607
|
+
targetshape : Tuple[int], optional
|
608
|
+
Shape of the target for convolution mode correction.
|
609
|
+
templateshape : Tuple[int], optional
|
610
|
+
Shape of the template for convolution mode correction.
|
611
|
+
convolution_shape : Tuple[int], optional
|
612
|
+
Shape used for convolution.
|
613
|
+
fourier_shift : Tuple[int], optional
|
614
|
+
Shift to apply for Fourier correction.
|
615
|
+
convolution_mode : str, optional
|
616
|
+
Convolution mode for padding correction.
|
617
|
+
**kwargs
|
618
|
+
Additional keyword arguments.
|
577
619
|
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
620
|
+
Returns
|
621
|
+
-------
|
622
|
+
tuple
|
623
|
+
Final result tuple (positions, rotations, scores, details).
|
624
|
+
"""
|
625
|
+
translations, rotations, scores, details = state
|
626
|
+
|
627
|
+
positions, valid_peaks = score_to_cart(
|
628
|
+
positions=translations,
|
629
|
+
fast_shape=fast_shape,
|
630
|
+
targetshape=targetshape,
|
631
|
+
templateshape=templateshape,
|
632
|
+
convolution_shape=convolution_shape,
|
633
|
+
fourier_shift=fourier_shift,
|
634
|
+
convolution_mode=convolution_mode,
|
635
|
+
**kwargs,
|
636
|
+
)
|
637
|
+
translations = be.to_cpu_array(positions[valid_peaks])
|
638
|
+
rotations = be.to_cpu_array(rotations[valid_peaks])
|
639
|
+
scores = be.to_cpu_array(scores[valid_peaks])
|
640
|
+
details = be.to_cpu_array(details[valid_peaks])
|
641
|
+
return translations, rotations, scores, details
|
583
642
|
|
584
643
|
|
585
644
|
class PeakCallerSort(PeakCaller):
|
tme/analyzer/proxy.py
ADDED
@@ -0,0 +1,122 @@
|
|
1
|
+
"""
|
2
|
+
Implements SharedAnalyzerProxy to manage shared memory of Analyzer instances
|
3
|
+
across different tasks.
|
4
|
+
|
5
|
+
This is primarily useful for CPU template matching, where parallelization can
|
6
|
+
be performed over rotations, rather than subsections of a large input volume.
|
7
|
+
|
8
|
+
Copyright (c) 2025 European Molecular Biology Laboratory
|
9
|
+
|
10
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
11
|
+
"""
|
12
|
+
|
13
|
+
from typing import Tuple
|
14
|
+
from multiprocessing import Manager
|
15
|
+
from multiprocessing.shared_memory import SharedMemory
|
16
|
+
|
17
|
+
from ..backends import backend as be
|
18
|
+
|
19
|
+
__all__ = ["StatelessSharedAnalyzerProxy", "SharedAnalyzerProxy"]
|
20
|
+
|
21
|
+
class StatelessSharedAnalyzerProxy:
|
22
|
+
"""
|
23
|
+
Proxy that wraps functional analyzers for concurrent access via shared memory.
|
24
|
+
|
25
|
+
Enables multiple processes/threads to safely update the same analyzer
|
26
|
+
while preserving the functional interface of the underlying analyzer.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(self, analyzer_class: type, analyzer_params: dict):
|
30
|
+
self._shared = False
|
31
|
+
self._process = self._direct_call
|
32
|
+
|
33
|
+
self._analyzer = analyzer_class(**analyzer_params)
|
34
|
+
|
35
|
+
def __call__(self, state, *args, **kwargs):
|
36
|
+
return self._process(state, *args, **kwargs)
|
37
|
+
|
38
|
+
def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
|
39
|
+
state = self._analyzer.init_state()
|
40
|
+
if shm_handler is not None:
|
41
|
+
self._shared = True
|
42
|
+
state = self._to_shared(state, shm_handler)
|
43
|
+
|
44
|
+
self._lock = Manager().Lock()
|
45
|
+
self._process = self._thread_safe_call
|
46
|
+
return state
|
47
|
+
|
48
|
+
def _to_shared(self, state: Tuple, shm_handler):
|
49
|
+
backend_arr = type(be.zeros((1), dtype=be._float_dtype))
|
50
|
+
|
51
|
+
ret = []
|
52
|
+
for v in state:
|
53
|
+
if isinstance(v, backend_arr):
|
54
|
+
v = be.to_sharedarr(v, shm_handler)
|
55
|
+
elif isinstance(v, dict):
|
56
|
+
v = Manager().dict(**v)
|
57
|
+
ret.append(v)
|
58
|
+
return tuple(ret)
|
59
|
+
|
60
|
+
def _shared_to_object(self, shared: type):
|
61
|
+
if not self._shared:
|
62
|
+
return shared
|
63
|
+
if isinstance(shared, tuple) and len(shared):
|
64
|
+
if isinstance(shared[0], SharedMemory):
|
65
|
+
return be.from_sharedarr(shared)
|
66
|
+
return shared
|
67
|
+
|
68
|
+
def _thread_safe_call(self, state, *args, **kwargs):
|
69
|
+
"""Thread-safe call to analyzer"""
|
70
|
+
with self._lock:
|
71
|
+
state = tuple(self._shared_to_object(x) for x in state)
|
72
|
+
return self._direct_call(state, *args, **kwargs)
|
73
|
+
|
74
|
+
def _direct_call(self, state, *args, **kwargs):
|
75
|
+
"""Direct call to analyzer without locking"""
|
76
|
+
return self._analyzer(state, *args, **kwargs)
|
77
|
+
|
78
|
+
def result(self, state, **kwargs):
|
79
|
+
"""Extract final result"""
|
80
|
+
final_state = state
|
81
|
+
if self._shared:
|
82
|
+
# Convert shared arrays back to regular arrays and copy to
|
83
|
+
# avoid array invalidation by shared memory handler
|
84
|
+
final_state = tuple(self._shared_to_object(x) for x in final_state)
|
85
|
+
return self._analyzer.result(final_state, **kwargs)
|
86
|
+
|
87
|
+
def merge(self, *args, **kwargs):
|
88
|
+
return self._analyzer.merge(*args, **kwargs)
|
89
|
+
|
90
|
+
|
91
|
+
class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
|
92
|
+
"""
|
93
|
+
Child of :py:class:`StatelessSharedAnalyzerProxy` that is aware
|
94
|
+
of the current analyzer state to emulate the previous analyzer interface.
|
95
|
+
"""
|
96
|
+
|
97
|
+
def __init__(
|
98
|
+
self,
|
99
|
+
analyzer_class: type,
|
100
|
+
analyzer_params: dict,
|
101
|
+
shm_handler: type = None,
|
102
|
+
**kwargs,
|
103
|
+
):
|
104
|
+
super().__init__(
|
105
|
+
analyzer_class=analyzer_class,
|
106
|
+
analyzer_params=analyzer_params,
|
107
|
+
)
|
108
|
+
if not self._analyzer.shareable:
|
109
|
+
shm_handler = None
|
110
|
+
self.init_state(shm_handler)
|
111
|
+
|
112
|
+
def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
|
113
|
+
self._state = super().init_state(shm_handler, *args, **kwargs)
|
114
|
+
|
115
|
+
def __call__(self, *args, **kwargs):
|
116
|
+
state = super().__call__(self._state, *args, **kwargs)
|
117
|
+
if not self._shared:
|
118
|
+
self._state = state
|
119
|
+
|
120
|
+
def result(self, **kwargs):
|
121
|
+
"""Extract final result"""
|
122
|
+
return super().result(self._state, **kwargs)
|
tme/backends/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
pyTME backend manager.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from typing import Dict, List
|