pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
  2. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
  3. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
  4. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
  5. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
  6. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/METADATA +10 -9
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
  8. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
  9. scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
  10. scripts/extract_candidates.py +224 -0
  11. scripts/match_template.py +224 -223
  12. scripts/postprocess.py +283 -163
  13. scripts/preprocess.py +11 -8
  14. scripts/preprocessor_gui.py +10 -9
  15. scripts/refine_matches.py +626 -0
  16. tests/preprocessing/test_frequency_filters.py +9 -4
  17. tests/test_analyzer.py +143 -138
  18. tests/test_matching_cli.py +85 -30
  19. tests/test_matching_exhaustive.py +1 -2
  20. tests/test_matching_optimization.py +4 -9
  21. tests/test_orientations.py +0 -1
  22. tme/__version__.py +1 -1
  23. tme/analyzer/__init__.py +2 -0
  24. tme/analyzer/_utils.py +25 -17
  25. tme/analyzer/aggregation.py +384 -220
  26. tme/analyzer/base.py +138 -0
  27. tme/analyzer/peaks.py +150 -91
  28. tme/analyzer/proxy.py +122 -0
  29. tme/backends/__init__.py +4 -3
  30. tme/backends/_cupy_utils.py +25 -24
  31. tme/backends/_jax_utils.py +4 -3
  32. tme/backends/cupy_backend.py +4 -13
  33. tme/backends/jax_backend.py +6 -8
  34. tme/backends/matching_backend.py +4 -3
  35. tme/backends/mlx_backend.py +4 -3
  36. tme/backends/npfftw_backend.py +7 -5
  37. tme/backends/pytorch_backend.py +14 -4
  38. tme/cli.py +126 -0
  39. tme/density.py +4 -3
  40. tme/filters/__init__.py +1 -1
  41. tme/filters/_utils.py +4 -3
  42. tme/filters/bandpass.py +6 -4
  43. tme/filters/compose.py +5 -4
  44. tme/filters/ctf.py +426 -214
  45. tme/filters/reconstruction.py +58 -28
  46. tme/filters/wedge.py +139 -61
  47. tme/filters/whitening.py +36 -36
  48. tme/matching_data.py +4 -3
  49. tme/matching_exhaustive.py +17 -16
  50. tme/matching_optimization.py +5 -4
  51. tme/matching_scores.py +4 -3
  52. tme/matching_utils.py +41 -3
  53. tme/memory.py +4 -3
  54. tme/orientations.py +9 -6
  55. tme/parser.py +5 -4
  56. tme/preprocessor.py +4 -3
  57. tme/rotations.py +10 -7
  58. tme/structure.py +4 -3
  59. tests/data/Maps/.DS_Store +0 -0
  60. tests/data/Structures/.DS_Store +0 -0
  61. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
  62. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
  63. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ Implements abstract base class for template matching analyzers.
3
+
4
+ Copyright (c) 2025 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+
9
+ from typing import Tuple, List
10
+ from abc import ABC, abstractmethod
11
+
12
+ __all__ = ["AbstractAnalyzer"]
13
+
14
+
15
class AbstractAnalyzer(ABC):
    """
    Common interface for template matching analyzers.

    Subclasses provide four operations: :py:meth:`init_state` builds a
    state, :py:meth:`__call__` incorporates new scores into a state,
    :py:meth:`result` converts a state into the final output and
    :py:meth:`merge` combines several such outputs.
    """

    @property
    def shareable(self):
        """
        Report whether the analyzer can be shared across processes.

        Returns
        -------
        bool
            False in this base class. An analyzer that supports shared
            memory operations and can safely be used from multiple
            processes should report True; one that must stay within a
            single process should report False.
        """
        return False

    @abstractmethod
    def init_state(self, *args, **kwargs) -> Tuple:
        """
        Create the initial analyzer state.

        Returns
        -------
        state
            Tuple holding the internal data structures of the analyzer.
            Its layout is defined by the concrete implementation.

        Notes
        -----
        The returned state is what subsequent calls to __call__ receive.
        It should hold every data structure required to accumulate the
        analysis results.
        """

    @abstractmethod
    def __call__(self, state, scores, rotation_matrix, **kwargs) -> Tuple:
        """
        Incorporate new scoring data into the analyzer state.

        Parameters
        ----------
        state : object
            Analyzer state obtained from init_state() or from an earlier
            invocation of __call__.
        scores : BackendArray
            Scores computed for the current rotation.
        rotation_matrix : BackendArray
            Rotation matrix that produced ``scores``.
        **kwargs : dict
            Implementation-specific keyword arguments.

        Returns
        -------
        state
            Analyzer state including the new scoring data.

        Notes
        -----
        Implementations should be purely functional, i.e. leave the input
        state untouched and return a new state carrying the update. The
        exact signature may differ between implementations.
        """

    @abstractmethod
    def result(self, state: Tuple, **kwargs) -> Tuple:
        """
        Convert the accumulated state into the final analysis result.

        Parameters
        ----------
        state : tuple
            Final analyzer state with all accumulated data.
        **kwargs : dict
            Keyword arguments steering result processing, for instance
            postprocessing parameters.

        Returns
        -------
        result
            Final analysis result. Its format is implementation specific
            but typically comprises processed scores, rotation information
            and metadata.

        Notes
        -----
        This is where the internal state is translated into the output
        format expected by the template matching pipeline. Postprocessing
        such as convolution mode correction or coordinate transformations
        may be applied here.
        """

    @classmethod
    @abstractmethod
    def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
        """
        Combine the results of several analyzer instances.

        Parameters
        ----------
        results : list
            Result objects produced by the result() method of multiple
            analyzer instances.
        **kwargs : dict
            Keyword arguments configuring the merge.

        Returns
        -------
        merged_result
            A single result object unifying all inputs.

        Notes
        -----
        Merging is what makes parallel processing possible: results from
        different processes or splits are combined into one. Overlapping
        data should be handled appropriately and consistently.
        """
tme/analyzer/peaks.py CHANGED
@@ -1,18 +1,20 @@
1
- """ Implements classes to analyze outputs from exhaustive template matching.
1
+ """
2
+ Implements classes to analyze outputs from exhaustive template matching.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from functools import wraps
9
- from abc import ABC, abstractmethod
10
- from typing import Tuple, List, Dict, Generator
10
+ from abc import abstractmethod
11
+ from typing import Tuple, List, Dict
11
12
 
12
13
  import numpy as np
13
14
  from skimage.feature import peak_local_max
14
15
  from skimage.registration._phase_cross_correlation import _upsampled_dft
15
16
 
17
+ from .base import AbstractAnalyzer
16
18
  from ._utils import score_to_cart
17
19
  from ..backends import backend as be
18
20
  from ..matching_utils import split_shape
@@ -141,7 +143,7 @@ def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
141
143
  yield from _generate_slices_recursive(0, ())
142
144
 
143
145
 
144
- class PeakCaller(ABC):
146
+ class PeakCaller(AbstractAnalyzer):
145
147
  """
146
148
  Base class for peak calling algorithms.
147
149
 
@@ -190,19 +192,7 @@ class PeakCaller(ABC):
190
192
  if min_boundary_distance < 0:
191
193
  raise ValueError("min_boundary_distance has to be non-negative.")
192
194
 
193
- ndim = len(shape)
194
- self.translations = be.full(
195
- (num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
196
- )
197
- self.rotations = be.full(
198
- (num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
199
- )
200
- for i in range(ndim):
201
- self.rotations[:, i, i] = 1.0
202
-
203
- self.scores = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
204
- self.details = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
205
-
195
+ self.shape = shape
206
196
  self.num_peaks = int(num_peaks)
207
197
  self.min_distance = int(min_distance)
208
198
  self.min_boundary_distance = int(min_boundary_distance)
@@ -213,31 +203,47 @@ class PeakCaller(ABC):
213
203
 
214
204
  self.min_score, self.max_score = min_score, max_score
215
205
 
216
- # Postprocessing arguments
217
- self.fourier_shift = kwargs.get("fourier_shift", None)
218
- self.convolution_mode = kwargs.get("convolution_mode", None)
219
- self.targetshape = kwargs.get("targetshape", None)
220
- self.templateshape = kwargs.get("templateshape", None)
221
-
222
- def __iter__(self) -> Generator:
206
+ @abstractmethod
207
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
223
208
  """
224
- Returns a generator to list objects containing translation,
225
- rotation, score and details of a given candidate.
209
+ Call peaks in the score space.
210
+
211
+ Parameters
212
+ ----------
213
+ scores : BackendArray
214
+ Score array to update analyzer with.
215
+ **kwargs : dict
216
+ Optional keyword arguments passed to underlying implementations.
217
+
218
+ Returns
219
+ -------
220
+ BackendArray
221
+ Peak positions (n, d).
222
+ BackendArray
223
+ Peak details (n, d).
226
224
  """
227
- self.peak_list = [
228
- be.to_cpu_array(self.translations),
229
- be.to_cpu_array(self.rotations),
230
- be.to_cpu_array(self.scores),
231
- be.to_cpu_array(self.details),
232
- ]
233
- yield from self.peak_list
225
+
226
+ def init_state(self):
227
+ ndim = len(self.shape)
228
+ translations = be.full(
229
+ (self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
230
+ )
231
+ rotations = be.full(
232
+ (self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
233
+ )
234
+ for i in range(ndim):
235
+ rotations[:, i, i] = 1.0
236
+
237
+ scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
238
+ details = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
239
+ return translations, rotations, scores, details
234
240
 
235
241
  def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
236
242
  if not len(peaks):
237
243
  return None
238
244
 
239
245
  valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
240
- if self.min_boundary_distance > 0:
246
+ if self.min_boundary_distance >= 0:
241
247
  upper_limit = be.subtract(
242
248
  be.to_backend_array(scores.shape), self.min_boundary_distance
243
249
  )
@@ -318,20 +324,34 @@ class PeakCaller(ABC):
318
324
  peak_positions = be.astype(peak_positions, int)
319
325
  return peak_positions, peak_details
320
326
 
321
- def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
327
+ def __call__(
328
+ self,
329
+ state: Tuple,
330
+ scores: BackendArray,
331
+ rotation_matrix: BackendArray,
332
+ **kwargs,
333
+ ) -> Tuple:
322
334
  """
323
335
  Update the internal parameter store based on input array.
324
336
 
325
337
  Parameters
326
338
  ----------
339
+ state : tuple
340
+ Current state tuple where:
341
+ - positions : BackendArray, (n, d) of peak positions
342
+ - rotations : BackendArray, (n, d, d) of corresponding rotations
343
+ - scores : BackendArray, (n, ) of peak scores
344
+ - details : BackendArray, (n, ) of peak details
327
345
  scores : BackendArray
328
- Score space data.
346
+ Array of new scores to update analyzer with.
329
347
  rotation_matrix : BackendArray
330
348
  Rotation matrix used to obtain the score array.
331
349
  **kwargs
332
350
  Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
333
351
  """
334
- for ret in self._call_peaks(scores=scores, rotation_matrix=rotation_matrix):
352
+ for ret in self._call_peaks(
353
+ scores=scores, rotation_matrix=rotation_matrix, **kwargs
354
+ ):
335
355
  peak_positions, peak_details = ret
336
356
  if peak_positions is None:
337
357
  continue
@@ -344,7 +364,6 @@ class PeakCaller(ABC):
344
364
  peak_scores = scores[tuple(peak_positions.T)]
345
365
  if peak_details is not None:
346
366
  peak_details = peak_details[valid_peaks]
347
- # peak_details, peak_scores = peak_scores, -peak_details
348
367
  else:
349
368
  peak_details = be.full(peak_scores.shape, fill_value=-1)
350
369
 
@@ -354,66 +373,57 @@ class PeakCaller(ABC):
354
373
  axis=0,
355
374
  )
356
375
 
357
- self._update(
376
+ state = self._update(
377
+ state,
358
378
  peak_positions=peak_positions,
359
379
  peak_details=peak_details,
360
380
  peak_scores=peak_scores,
361
- rotations=rotations,
381
+ peak_rotations=rotations,
362
382
  )
363
383
 
364
- return None
365
-
366
- @abstractmethod
367
- def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
368
- """
369
- Call peaks in the score space.
370
-
371
- Parameters
372
- ----------
373
- scores : BackendArray
374
- Score array.
375
- **kwargs : dict
376
- Optional keyword arguments passed to underlying implementations.
377
-
378
- Returns
379
- -------
380
- Tuple[BackendArray, BackendArray]
381
- Array of peak coordinates and peak details.
382
- """
384
+ return state
383
385
 
384
386
  @classmethod
385
- def merge(cls, candidates=List[List], **kwargs) -> Tuple:
387
+ def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
386
388
  """
387
389
  Merge multiple instances of :py:class:`PeakCaller`.
388
390
 
389
391
  Parameters
390
392
  ----------
391
- candidates : list of lists
392
- Obtained by invoking list on the generator returned by __iter__.
393
+ results : list of tuple
394
+ List of instance results created by applying `result`.
393
395
  **kwargs
394
396
  Optional keyword arguments.
395
397
 
396
398
  Returns
397
399
  -------
398
- Tuple
399
- Tuple of translation, rotation, score and details of candidates.
400
+ NDArray
401
+ Peak positions (n, d).
402
+ NDArray
403
+ Peak rotation matrices (n, d, d).
404
+ NDArray
405
+ Peak scores (n, ).
406
+ NDArray
407
+ Peak details (n,).
400
408
  """
401
409
  if "shape" not in kwargs:
402
- kwargs["shape"] = tuple(1 for _ in range(candidates[0][0].shape[1]))
410
+ kwargs["shape"] = tuple(1 for _ in range(results[0][0].shape[1]))
403
411
 
404
412
  base = cls(**kwargs)
405
- for candidate in candidates:
406
- if len(candidate) == 0:
413
+ base_state = base.init_state()
414
+ for result in results:
415
+ if len(result) == 0:
407
416
  continue
408
- peak_positions, rotations, peak_scores, peak_details = candidate
409
- base._update(
417
+ peak_positions, rotations, peak_scores, peak_details = result
418
+ base_state = base._update(
419
+ base_state,
410
420
  peak_positions=be.to_backend_array(peak_positions),
411
421
  peak_details=be.to_backend_array(peak_details),
412
422
  peak_scores=be.to_backend_array(peak_scores),
413
- rotations=be.to_backend_array(rotations),
423
+ peak_rotations=be.to_backend_array(rotations),
414
424
  offset=kwargs.get("offset", None),
415
425
  )
416
- return tuple(base)
426
+ return base_state
417
427
 
418
428
  @staticmethod
419
429
  def oversample_peaks(
@@ -520,10 +530,11 @@ class PeakCaller(ABC):
520
530
 
521
531
  def _update(
522
532
  self,
533
+ state,
523
534
  peak_positions: BackendArray,
524
535
  peak_details: BackendArray,
525
536
  peak_scores: BackendArray,
526
- rotations: BackendArray,
537
+ peak_rotations: BackendArray,
527
538
  offset: BackendArray = None,
528
539
  ):
529
540
  """
@@ -542,14 +553,15 @@ class PeakCaller(ABC):
542
553
  offset : BackendArray, optional
543
554
  Translation offset, e.g. from splitting, (d, ).
544
555
  """
556
+ translations, rotations, scores, details = state
545
557
  if offset is not None:
546
558
  offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
547
559
  peak_positions = be.add(peak_positions, offset, out=peak_positions)
548
560
 
549
- positions = be.concatenate((self.translations, peak_positions))
550
- rotations = be.concatenate((self.rotations, rotations))
551
- scores = be.concatenate((self.scores, peak_scores))
552
- details = be.concatenate((self.details, peak_details))
561
+ positions = be.concatenate((translations, peak_positions))
562
+ rotations = be.concatenate((rotations, peak_rotations))
563
+ scores = be.concatenate((scores, peak_scores))
564
+ details = be.concatenate((details, peak_details))
553
565
 
554
566
  # topk filtering after distances yields more distributed peak calls
555
567
  distance_order = filter_points_indices(
@@ -564,22 +576,69 @@ class PeakCaller(ABC):
564
576
  )
565
577
  final_order = distance_order[top_scores]
566
578
 
567
- self.translations = positions[final_order, :]
568
- self.rotations = rotations[final_order, :]
569
- self.scores = scores[final_order]
570
- self.details = details[final_order]
579
+ translations = positions[final_order, :]
580
+ rotations = rotations[final_order, :]
581
+ scores = scores[final_order]
582
+ details = details[final_order]
583
+ return translations, rotations, scores, details
571
584
 
572
- def _postprocess(self, **kwargs):
573
- if not len(self.translations):
574
- return self
585
+ def result(
586
+ self,
587
+ state,
588
+ fast_shape: Tuple[int] = None,
589
+ targetshape: Tuple[int] = None,
590
+ templateshape: Tuple[int] = None,
591
+ convolution_shape: Tuple[int] = None,
592
+ fourier_shift: Tuple[int] = None,
593
+ convolution_mode: str = None,
594
+ **kwargs,
595
+ ):
596
+ """
597
+ Finalize the analysis result with optional postprocessing.
575
598
 
576
- positions, valid_peaks = score_to_cart(self.translations, **kwargs)
599
+ Parameters
600
+ ----------
601
+ state : tuple
602
+ Current state tuple where:
603
+ - positions : BackendArray, (n, d) of peak positions
604
+ - rotations : BackendArray, (n, d, d) of correponding rotations
605
+ - scores : BackendArray, (n, ) of peak scores
606
+ - details : BackendArray, (n, ) of peak details
607
+ targetshape : Tuple[int], optional
608
+ Shape of the target for convolution mode correction.
609
+ templateshape : Tuple[int], optional
610
+ Shape of the template for convolution mode correction.
611
+ convolution_shape : Tuple[int], optional
612
+ Shape used for convolution.
613
+ fourier_shift : Tuple[int], optional.
614
+ Shift to apply for Fourier correction.
615
+ convolution_mode : str, optional
616
+ Convolution mode for padding correction.
617
+ **kwargs
618
+ Additional keyword arguments.
577
619
 
578
- self.translations = positions[valid_peaks]
579
- self.rotations = self.rotations[valid_peaks]
580
- self.scores = self.scores[valid_peaks]
581
- self.details = self.details[valid_peaks]
582
- return self
620
+ Returns
621
+ -------
622
+ tuple
623
+ Final result tuple (positions, rotations, scores, details).
624
+ """
625
+ translations, rotations, scores, details = state
626
+
627
+ positions, valid_peaks = score_to_cart(
628
+ positions=translations,
629
+ fast_shape=fast_shape,
630
+ targetshape=targetshape,
631
+ templateshape=templateshape,
632
+ convolution_shape=convolution_shape,
633
+ fourier_shift=fourier_shift,
634
+ convolution_mode=convolution_mode,
635
+ **kwargs,
636
+ )
637
+ translations = be.to_cpu_array(positions[valid_peaks])
638
+ rotations = be.to_cpu_array(rotations[valid_peaks])
639
+ scores = be.to_cpu_array(scores[valid_peaks])
640
+ details = be.to_cpu_array(details[valid_peaks])
641
+ return translations, rotations, scores, details
583
642
 
584
643
 
585
644
  class PeakCallerSort(PeakCaller):
tme/analyzer/proxy.py ADDED
@@ -0,0 +1,122 @@
1
+ """
2
+ Implements SharedAnalyzerProxy to manage shared memory of Analyzer instances
3
+ across different tasks.
4
+
5
+ This is primarily useful for CPU template matching, where parallelization can
6
+ be performed over rotations, rather than subsections of a large input volume.
7
+
8
+ Copyright (c) 2025 European Molecular Biology Laboratory
9
+
10
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
11
+ """
12
+
13
+ from typing import Tuple
14
+ from multiprocessing import Manager
15
+ from multiprocessing.shared_memory import SharedMemory
16
+
17
+ from ..backends import backend as be
18
+
19
+ __all__ = ["StatelessSharedAnalyzerProxy", "SharedAnalyzerProxy"]
20
+
21
class StatelessSharedAnalyzerProxy:
    """
    Proxy that wraps functional analyzers for concurrent access via shared memory.

    The proxy instantiates the analyzer and forwards calls to it. When a
    shared memory handler is passed to :py:meth:`init_state`, the state is
    moved to shared objects and every update is serialized through a manager
    lock, while the functional interface of the wrapped analyzer is kept.

    Parameters
    ----------
    analyzer_class : type
        Analyzer class to instantiate.
    analyzer_params : dict
        Keyword arguments passed to ``analyzer_class``.
    """

    def __init__(self, analyzer_class: type, analyzer_params: dict):
        # Until init_state is called with a shared memory handler, calls go
        # straight to the analyzer without locking.
        self._shared = False
        self._process = self._direct_call

        self._analyzer = analyzer_class(**analyzer_params)

    def __call__(self, state, *args, **kwargs):
        """Update ``state`` through the currently active call strategy."""
        return self._process(state, *args, **kwargs)

    def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
        """
        Create the analyzer state, optionally backed by shared memory.

        Parameters
        ----------
        shm_handler : optional
            Shared memory handler passed to ``be.to_sharedarr``. If None,
            the plain analyzer state is returned and no locking is set up.
        *args, **kwargs
            Accepted for interface compatibility and not used.

        Returns
        -------
        tuple
            Initial analyzer state.
        """
        state = self._analyzer.init_state()
        if shm_handler is not None:
            self._shared = True
            state = self._to_shared(state, shm_handler)

            self._lock = Manager().Lock()
            self._process = self._thread_safe_call
        return state

    def _to_shared(self, state: Tuple, shm_handler):
        """Convert backend arrays and dicts in ``state`` to shared objects."""
        # Determine the array type of the active backend from a dummy array
        backend_arr = type(be.zeros((1), dtype=be._float_dtype))

        ret = []
        for v in state:
            if isinstance(v, backend_arr):
                v = be.to_sharedarr(v, shm_handler)
            elif isinstance(v, dict):
                # Pass the mapping positionally: unpacking with ** would
                # raise TypeError for dictionaries with non-string keys.
                v = Manager().dict(v)
            ret.append(v)
        return tuple(ret)

    def _shared_to_object(self, shared: type):
        """Return the backend array behind a shared array description."""
        if not self._shared:
            return shared
        # Shared arrays are recognised as non-empty tuples whose first
        # element is a SharedMemory instance; anything else passes through.
        if isinstance(shared, tuple) and len(shared):
            if isinstance(shared[0], SharedMemory):
                return be.from_sharedarr(shared)
        return shared

    def _thread_safe_call(self, state, *args, **kwargs):
        """Thread-safe call to analyzer"""
        with self._lock:
            state = tuple(self._shared_to_object(x) for x in state)
            return self._direct_call(state, *args, **kwargs)

    def _direct_call(self, state, *args, **kwargs):
        """Direct call to analyzer without locking"""
        return self._analyzer(state, *args, **kwargs)

    def result(self, state, **kwargs):
        """Extract final result"""
        final_state = state
        if self._shared:
            # Convert shared arrays back to regular backend arrays.
            # NOTE(review): no explicit copy is made here; this presumably
            # relies on the analyzer's result() to detach the data from the
            # shared memory handler - confirm.
            final_state = tuple(self._shared_to_object(x) for x in final_state)
        return self._analyzer.result(final_state, **kwargs)

    def merge(self, *args, **kwargs):
        """Forward to the merge method of the wrapped analyzer."""
        return self._analyzer.merge(*args, **kwargs)
89
+
90
+
91
class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
    """
    Child of :py:class:`StatelessSharedAnalyzerProxy` that is aware
    of the current analyzer state to emulate the previous analyzer interface.

    The state is created on construction and kept on the instance, so
    callers neither pass nor receive it.

    Parameters
    ----------
    analyzer_class : type
        Analyzer class to instantiate.
    analyzer_params : dict
        Keyword arguments passed to ``analyzer_class``.
    shm_handler : optional
        Shared memory handler. Ignored if the analyzer is not shareable.
    **kwargs
        Accepted for interface compatibility and not used.
    """

    def __init__(
        self,
        analyzer_class: type,
        analyzer_params: dict,
        shm_handler: type = None,
        **kwargs,
    ):
        super().__init__(
            analyzer_class=analyzer_class,
            analyzer_params=analyzer_params,
        )
        # Analyzers that are not shareable always use the unshared path
        if not self._analyzer.shareable:
            shm_handler = None
        self.init_state(shm_handler)

    def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
        """
        Create the analyzer state and store it on the instance.

        Returns
        -------
        tuple
            The newly created state, matching the parent class contract.
        """
        self._state = super().init_state(shm_handler, *args, **kwargs)
        # Previously nothing was returned despite the Tuple annotation
        return self._state

    def __call__(self, *args, **kwargs):
        """Update the stored state with the given arguments."""
        state = super().__call__(self._state, *args, **kwargs)
        # The stored state is only replaced on the unshared path; with shared
        # memory the originally stored state objects are kept.
        if not self._shared:
            self._state = state

    def result(self, **kwargs):
        """Extract final result"""
        return super().result(self._state, **kwargs)
tme/backends/__init__.py CHANGED
@@ -1,8 +1,9 @@
1
- """ pyTME backend manager.
1
+ """
2
+ pyTME backend manager.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from typing import Dict, List