pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3.0.data/scripts/match_template.py +1106 -0
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
- pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
- pytme-0.3.0.dist-info/RECORD +126 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
- pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +349 -378
- pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +320 -190
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +85 -19
- scripts/pytme_runner.py +771 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -53
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +396 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -201
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +158 -28
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.dist-info/RECORD +0 -119
- pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
tme/filters/whitening.py
CHANGED
@@ -1,8 +1,9 @@
-"""
+"""
+Implements class BandPassFilter to create Fourier filter representations.
 
-
+Copyright (c) 2024 European Molecular Biology Laboratory
 
-
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 from typing import Tuple, Dict
@@ -14,7 +15,7 @@ from scipy.ndimage import map_coordinates
 from ..types import BackendArray
 from ..backends import backend as be
 from .compose import ComposableFilter
-from ._utils import fftfreqn, compute_fourier_shape
+from ._utils import fftfreqn, compute_fourier_shape, shift_fourier
 
 __all__ = ["LinearWhiteningFilter"]
 
@@ -23,11 +24,6 @@ class LinearWhiteningFilter(ComposableFilter):
     """
     Compute Fourier power spectrums and perform whitening.
 
-    Parameters
-    ----------
-    **kwargs : Dict, optional
-        Additional keyword arguments.
-
     References
     ----------
     .. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
@@ -38,7 +34,7 @@ class LinearWhiteningFilter(ComposableFilter):
        13375 (2023)
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, *args, **kwargs):
        pass
 
     @staticmethod
@@ -103,13 +99,6 @@ class LinearWhiteningFilter(ComposableFilter):
         shape_is_real_fourier: bool = True,
         order: int = 1,
     ) -> BackendArray:
-        """
-        References
-        ----------
-        .. [1] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
-           R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
-           13375 (2023)
-        """
         grid = fftfreqn(
             shape=shape,
             sampling_rate=0.5,
@@ -123,11 +112,13 @@ class LinearWhiteningFilter(ComposableFilter):
 
     def __call__(
         self,
+        shape: Tuple[int],
         data: BackendArray = None,
         data_rfft: BackendArray = None,
         n_bins: int = None,
         batch_dimension: int = None,
         order: int = 1,
+        return_real_fourier: bool = True,
         **kwargs: Dict,
     ) -> Dict:
         """
@@ -135,6 +126,8 @@ class LinearWhiteningFilter(ComposableFilter):
 
         Parameters
         ----------
+        shape : tuple of ints
+            Shape of the returned whitening filter.
         data : BackendArray, optional
             The input data, defaults to None.
         data_rfft : BackendArray, optional
@@ -143,49 +136,51 @@ class LinearWhiteningFilter(ComposableFilter):
             The number of bins for computing the spectrum, defaults to None.
         batch_dimension : int, optional
             Batch dimension to average over.
-
-
+        return_real_fourier : tuple of int
+            Return a shape compliant with rfft, i.e., omit the negative frequencies
+            terms resulting in a return shape (*shape[:-1], shape[-1]//2+1)
         **kwargs : Dict
             Additional keyword arguments.
 
         Returns
         -------
-
-
+        dict
+            data: BackendArray
+                The filter mask.
+            shape: tuple of ints
+                The requested filter shape
+            return_real_fourier: bool
+                Whether data is compliant with rfftn.
+            is_multiplicative_filter: bool
+                Whether the filter is multiplicative in Fourier space.
         """
         if data_rfft is None:
-            data_rfft =
+            data_rfft = be.rfftn(data)
 
         data_rfft = be.to_numpy_array(data_rfft)
-
         bins, radial_averages = self._compute_spectrum(
             data_rfft, n_bins, batch_dimension
         )
+        shape = tuple(int(x) for i, x in enumerate(shape) if i != batch_dimension)
 
-
-
-
-            filter_mask[cutoff] = radial_averages[bins[cutoff]]
-        else:
-            shape = bins.shape
-            if kwargs.get("shape", False):
-                shape = compute_fourier_shape(
-                    shape=kwargs.get("shape"),
-                    shape_is_real_fourier=kwargs.get("shape_is_real_fourier", False),
-                )
-
-            filter_mask = self._interpolate_spectrum(
-                spectrum=radial_averages,
+        shape_filter = shape
+        if return_real_fourier:
+            shape_filter = compute_fourier_shape(
                 shape=shape,
-                shape_is_real_fourier=
+                shape_is_real_fourier=False,
             )
 
-
-
-
+        ret = self._interpolate_spectrum(
+            spectrum=radial_averages,
+            shape=shape_filter,
+            shape_is_real_fourier=return_real_fourier,
        )
 
+        ret = shift_fourier(data=ret, shape_is_real_fourier=return_real_fourier)
+
        return {
-            "data": be.to_backend_array(
+            "data": be.to_backend_array(ret),
+            "shape": shape,
+            "return_real_fourier": return_real_fourier,
            "is_multiplicative_filter": True,
        }
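
The net effect of these changes is that LinearWhiteningFilter.__call__ now takes the output shape explicitly and returns its metadata alongside the filter mask. A minimal usage sketch, assuming the class is re-exported from tme.filters and that a NumPy array is an acceptable data input; keys and defaults are taken from the hunks above:

    import numpy as np
    from tme.filters import LinearWhiteningFilter  # import path assumed

    volume = np.random.rand(64, 64, 64).astype(np.float32)  # stand-in tomogram

    result = LinearWhiteningFilter()(
        shape=volume.shape,        # new explicit shape argument
        data=volume,               # data_rfft is computed internally via be.rfftn
        return_real_fourier=True,  # rfft-compliant output (*shape[:-1], shape[-1] // 2 + 1)
    )
    whitening_mask = result["data"]
    print(result["shape"], result["return_real_fourier"], result["is_multiplicative_filter"])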
tme/matching_data.py
CHANGED
@@ -1,8 +1,9 @@
-"""
+"""
+Class representation of template matching data.
 
-
+Copyright (c) 2023 European Molecular Biology Laboratory
 
-
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import warnings
@@ -174,7 +175,7 @@ class MatchingData:
         target_pad: NDArray = None,
         template_pad: NDArray = None,
         invert_target: bool = False,
-    ) -> "MatchingData":
+    ) -> Tuple["MatchingData", Tuple]:
         """
         Subset class instance based on slices.
 
@@ -193,6 +194,8 @@ class MatchingData:
         -------
         :py:class:`MatchingData`
             Newly allocated subset of class instance.
+        Tuple
+            Translation offset to merge analyzers.
 
         Examples
         --------
@@ -250,8 +253,9 @@ class MatchingData:
         target_offset[mask] = [x.start for x in target_slice]
         mask = np.subtract(1, self._target_batch).astype(bool)
         template_offset = np.zeros(len(self._output_template_shape), dtype=int)
-        template_offset[mask] = [x.start for x in template_slice]
-
+        template_offset[mask] = [x.start for x, b in zip(template_slice, mask) if b]
+
+        translation_offset = tuple(x for x in target_offset)
 
         ret.target_filter = self.target_filter
         ret.template_filter = self.template_filter
@@ -261,7 +265,7 @@ class MatchingData:
             template_dim=getattr(self, "_template_dim", None),
         )
 
-        return ret
+        return ret, translation_offset
 
     def to_backend(self):
         """
@@ -322,11 +326,6 @@ class MatchingData:
 
         target_ndim -= len(target_dims)
         template_ndim -= len(template_dims)
-
-        if target_ndim != template_ndim:
-            raise ValueError(
-                f"Dimension mismatch: Target ({target_ndim}) Template ({template_ndim})."
-            )
         self._set_matching_dimension(
             target_dims=target_dims, template_dims=template_dims
         )
@@ -491,29 +490,26 @@ class MatchingData:
     def _fourier_padding(
         target_shape: Tuple[int],
         template_shape: Tuple[int],
-
+        pad_target: bool = False,
         batch_mask: Tuple[int] = None,
     ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
-        fourier_pad = template_shape
-        fourier_shift = np.zeros_like(template_shape)
-
         if batch_mask is None:
             batch_mask = np.zeros_like(template_shape)
         batch_mask = np.asarray(batch_mask)
 
-
-        fourier_pad = np.ones(len(fourier_pad), dtype=int)
+        fourier_pad = np.ones(len(template_shape), dtype=int)
         fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
         fourier_pad = np.add(fourier_pad, batch_mask)
 
+        # Avoid padding batch dimensions
         pad_shape = np.maximum(target_shape, template_shape)
+        pad_shape = np.maximum(target_shape, np.multiply(1 - batch_mask, pad_shape))
         ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
         conv_shape, fast_shape, fast_ft_shape = ret
 
         template_mod = np.mod(template_shape, 2)
-
-
-        fourier_shift = np.subtract(fourier_shift, template_mod)
+        fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
+        fourier_shift = np.subtract(fourier_shift, template_mod)
 
         shape_diff = np.multiply(
             np.subtract(target_shape, template_shape), 1 - batch_mask
@@ -522,34 +518,27 @@ class MatchingData:
         if np.sum(shape_mask):
             shape_shift = np.divide(shape_diff, 2)
             offset = np.mod(shape_diff, 2)
-
-
-
-
-                )
-            else:
-                warnings.warn(
-                    "Template is larger than target and padding is turned off. Consider "
-                    "swapping them or activate padding. Correcting the shift for now."
-                )
+            warnings.warn(
+                "Template is larger than target and padding is turned off. Consider "
+                "swapping them or activate padding. Correcting the shift for now."
+            )
             shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
             fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
 
-
+        if pad_target:
+            fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
 
+        fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
         return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
 
-    def fourier_padding(
-        self, pad_fourier: bool = False
-    ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
+    def fourier_padding(self, pad_target: bool = False) -> Tuple:
         """
         Computes efficient shape four Fourier transforms and potential associated shifts.
 
         Parameters
         ----------
-
-
-            shape and template shape minus one, False by default.
+        pad_target : bool, optional
+            Whether the target has been padded to the full convolution shape.
 
         Returns
         -------
@@ -564,7 +553,7 @@ class MatchingData:
             target_shape=be.to_numpy_array(self._output_target_shape),
             template_shape=be.to_numpy_array(self._output_template_shape),
             batch_mask=be.to_numpy_array(self._batch_mask),
-
+            pad_target=pad_target,
         )
 
     def computation_schedule(
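
A sketch of the updated MatchingData API implied by the hunks above: subset_by_slice now returns the subset together with a translation offset, and fourier_padding takes pad_target instead of the former padding flag. The constructor arguments and default backend behaviour are assumptions made for illustration:

    import numpy as np
    from tme.matching_data import MatchingData

    target = np.random.rand(50, 50, 50).astype(np.float32)
    template = np.random.rand(12, 12, 12).astype(np.float32)

    matching_data = MatchingData(target=target, template=template)  # constructor args assumed
    subset, translation_offset = matching_data.subset_by_slice(
        target_slice=tuple(slice(0, 25) for _ in range(target.ndim)),
    )

    subset.to_backend()  # mirrors the call order used in scan()
    conv, fast_shape, fast_ft_shape, fourier_shift = subset.fourier_padding(pad_target=False)
    print(translation_offset, fast_shape, fourier_shift)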
tme/matching_exhaustive.py
CHANGED
@@ -1,8 +1,9 @@
-"""
+"""
+Implements cross-correlation based template matching using different metrics.
 
-
+Copyright (c) 2023 European Molecular Biology Laboratory
 
-
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import sys
@@ -19,6 +20,7 @@ from .filters import Compose
 from .backends import backend as be
 from .matching_utils import split_shape
 from .types import CallbackClass, MatchingData
+from .analyzer.proxy import SharedAnalyzerProxy
 from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
 from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
 
@@ -147,7 +149,7 @@ def scan(
     n_jobs: int = 4,
     callback_class: CallbackClass = None,
     callback_class_args: Dict = {},
-
+    pad_target: bool = True,
     pad_template_filter: bool = True,
     interpolation_order: int = 3,
     jobs_per_callback_class: int = 8,
@@ -174,8 +176,8 @@ def scan(
         Analyzer class pointer to operate on computed scores.
     callback_class_args : dict, optional
         Arguments passed to the callback_class. Default is an empty dictionary.
-
-        Whether to pad target
+    pad_target: bool, optional
+        Whether to pad target to the full convolution shape.
     pad_template_filter: bool, optional
         Whether to pad potential template filters to the full convolution shape.
     interpolation_order : int, optional
@@ -208,17 +210,17 @@ def scan(
     >>> )
 
     """
-    matching_data = matching_data.subset_by_slice(
+    matching_data, translation_offset = matching_data.subset_by_slice(
         target_slice=target_slice,
         template_slice=template_slice,
-        target_pad=matching_data.target_padding(pad_target=
+        target_pad=matching_data.target_padding(pad_target=pad_target),
    )
 
    matching_data.to_backend()
    template_shape = matching_data._batch_shape(
        matching_data.template.shape, matching_data._template_batch
    )
-    conv, fwd, inv, shift = matching_data.fourier_padding(
+    conv, fwd, inv, shift = matching_data.fourier_padding(pad_target=pad_target)
 
    template_filter = _setup_template_filter_apply_target_filter(
        matching_data=matching_data,
@@ -229,19 +231,20 @@ def scan(
 
     default_callback_args = {
         "shape": fwd,
-        "offset":
+        "offset": translation_offset,
        "fourier_shift": shift,
        "fast_shape": fwd,
        "targetshape": matching_data._output_shape,
        "templateshape": template_shape,
        "convolution_shape": conv,
        "thread_safe": n_jobs > 1,
-        "convolution_mode": "valid" if
+        "convolution_mode": "valid" if pad_target else "same",
        "shm_handler": shm_handler,
        "only_unique_rotations": True,
        "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
        "n_rotations": matching_data.rotations.shape[0],
    }
+    callback_class_args["inversion_mapping"] = n_jobs == 1
    default_callback_args.update(callback_class_args)
 
    setup = matching_setup(
@@ -257,13 +260,17 @@ def scan(
     matching_data._free_data()
     be.free_cache()
 
-    # Some analyzers cannot be shared across processes
-    if not getattr(callback_class, "is_shareable", False):
-        jobs_per_callback_class = 1
-
     n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
     callback_classes = [
-
+        (
+            SharedAnalyzerProxy(
+                callback_class,
+                default_callback_args,
+                shm_handler=shm_handler if n_jobs > 1 else None,
+            )
+            if callback_class
+            else None
+        )
         for _ in range(n_callback_classes)
     ]
     ret = Parallel(n_jobs=n_jobs)(
@@ -276,14 +283,9 @@ def scan(
         )
         for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
     )
-
-    # TODO: Make sure peak callers are thread safe to begin with
-    if not getattr(callback_class, "is_shareable", False):
-        callback_classes = ret
-
     callbacks = [
-
-        for callback in
+        callback.result(**default_callback_args)
+        for callback in ret[:n_callback_classes]
         if callback
     ]
     be.free_cache()
@@ -423,14 +425,17 @@ def scan_subsets(
     splits = tuple(product(target_splits, template_splits))
 
     outer_jobs, inner_jobs = job_schedule
-    if
+    if be._backend_name == "jax":
+        func = be.scan
+
         corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
-        results =
+        results = func(
            matching_data=matching_data,
            splits=splits,
            n_jobs=outer_jobs,
            rotate_mask=matching_score != corr_scoring,
            callback_class=callback_class,
+            callback_class_args=callback_class_args,
        )
    else:
        results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
@@ -445,7 +450,7 @@ def scan_subsets(
                 callback_class=callback_class,
                 callback_class_args=callback_class_args,
                 interpolation_order=interpolation_order,
-
+                pad_target=pad_target_edges,
                gpu_index=index % outer_jobs,
                pad_template_filter=pad_template_filter,
                target_slice=target_split,
@@ -454,7 +459,6 @@ def scan_subsets(
             for index, (target_split, template_split) in enumerate(splits)
         ]
     )
-
     matching_data._free_data()
     if callback_class is not None:
         return callback_class.merge(results, **callback_class_args)
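
For callers, the visible change in scan is the pad_target keyword; the callback bookkeeping with SharedAnalyzerProxy happens internally. A hedged sketch of a call, where the registry key, the analyzer class, and the matching_setup/matching_score parameter names are assumptions based on the surrounding code rather than the diff itself:

    import numpy as np
    from tme.matching_data import MatchingData
    from tme.matching_exhaustive import scan
    from tme.matching_scores import MATCHING_EXHAUSTIVE_REGISTER
    from tme.analyzer import MaxScoreOverRotations  # assumed analyzer export

    target = np.random.rand(64, 64, 64).astype(np.float32)
    template = np.random.rand(16, 16, 16).astype(np.float32)

    matching_data = MatchingData(target=target, template=template)
    matching_data.rotations = np.eye(3)[None]  # single identity rotation, setter assumed

    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER["CORR"]
    result = scan(
        matching_data=matching_data,
        matching_setup=matching_setup,
        matching_score=matching_score,
        n_jobs=1,
        callback_class=MaxScoreOverRotations,
        pad_target=True,  # renamed flag controlling convolution-shape padding
    )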
tme/matching_optimization.py
CHANGED
@@ -1,8 +1,9 @@
-"""
+"""
+Implements methods for non-exhaustive template matching.
 
-
+Copyright (c) 2023 European Molecular Biology Laboratory
 
-
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import warnings
@@ -1308,4 +1309,4 @@ def optimize_match(
         result.x = np.zeros_like(result.x)
     translation, rotation = result.x[:ndim], result.x[ndim:]
     rotation_matrix = euler_to_rotationmatrix(rotation)
-    return translation, rotation_matrix, result.fun
+    return translation, rotation_matrix, float(result.fun)
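
Besides the docstring header, the functional change here is that optimize_match now returns its score as a built-in float rather than a backend scalar. The surrounding context also shows the Euler-angle conversion it relies on; a standalone sketch of that step, assuming euler_to_rotationmatrix is exposed by tme.rotations and expects angles in degrees:

    import numpy as np
    from tme.rotations import euler_to_rotationmatrix  # module location assumed

    angles = np.array([10.0, 0.0, 45.0])          # hypothetical Euler angles
    rotation_matrix = euler_to_rotationmatrix(angles)
    print(rotation_matrix.shape)                   # expected: (3, 3)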
tme/matching_scores.py
CHANGED
@@ -1,8 +1,9 @@
-"""
+"""
+Implements a range of cross-correlation coefficients.
 
-
+Copyright (c) 2023-2024 European Molecular Biology Laboratory
 
-
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import warnings
@@ -592,20 +593,27 @@ def corr_scoring(
         **_fftargs,
     )
 
+    center = be.divide(be.to_backend_array(template.shape) - 1, 2)
     unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
+
+    template_rot = be.zeros(template.shape, be._float_dtype)
     for index in range(rotations.shape[0]):
+        # d+1, d+1 rigid transform matrix from d,d rotation matrix
         rotation = rotations[index]
-
-
+        matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
+        template_rot, _ = be.rigid_transform(
             arr=template,
-            rotation_matrix=
-            out=
-            use_geometric_center=True,
+            rotation_matrix=matrix,
+            out=template_rot,
             order=interpolation_order,
-            cache=
+            cache=True,
        )
-
-
+
+        template_rot = template_filter_func(template_rot, ft_temp, template_filter)
+        norm_template(template_rot, template_mask, mask_sum)
+
+        arr = be.fill(arr, 0)
+        arr[unpadded_slice] = template_rot
 
         ft_temp = rfftn(arr, ft_temp)
         ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
@@ -728,7 +736,7 @@ def flc_scoring(
             out_mask=temp,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=
+            cache=True,
         )
 
         n_obs = be.sum(temp)
@@ -874,7 +882,7 @@ def mcc_scoring(
             out_mask=temp,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=
+            cache=True,
         )
 
         template_filter_func(template_rot, temp_ft, template_filter)
@@ -1034,7 +1042,8 @@ def flc_scoring2(
             out_mask=tmp_sqz,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=
+            cache=True,
+            batched=True,
         )
         n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
         arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
@@ -1154,7 +1163,8 @@ def corr_scoring2(
             out=arr_sqz,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=
+            cache=True,
+            batched=True,
         )
         arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
         norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)
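
The reworked corr_scoring loop precomputes a single rigid-transform matrix per rotation about the geometric center (shape - 1) / 2, as the added comment notes. A NumPy-only sketch of what such a (d+1, d+1) matrix presumably looks like; the actual be._rigid_transform_matrix implementation may differ:

    import numpy as np

    def rigid_transform_matrix(rotation_matrix: np.ndarray, center: np.ndarray) -> np.ndarray:
        d = rotation_matrix.shape[0]
        forward, rotation, backward = np.eye(d + 1), np.eye(d + 1), np.eye(d + 1)
        forward[:d, d] = -center            # shift the geometric center to the origin
        rotation[:d, :d] = rotation_matrix  # rotate about the origin
        backward[:d, d] = center            # shift back
        return backward @ rotation @ forward

    template_shape = (32, 32, 32)
    center = (np.asarray(template_shape) - 1) / 2
    matrix = rigid_transform_matrix(np.eye(3), center)
    print(np.allclose(matrix, np.eye(4)))   # identity rotation -> identity transform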