pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,497 @@
|
|
1
|
+
""" Implements cross-correlation based template matching using different metrics.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
import sys
|
9
|
+
import warnings
|
10
|
+
from math import prod
|
11
|
+
from functools import wraps
|
12
|
+
from itertools import product
|
13
|
+
from typing import Callable, Tuple, Dict, Optional
|
14
|
+
|
15
|
+
from joblib import Parallel, delayed
|
16
|
+
from multiprocessing.managers import SharedMemoryManager
|
17
|
+
|
18
|
+
from .filters import Compose
|
19
|
+
from .backends import backend as be
|
20
|
+
from .matching_utils import split_shape
|
21
|
+
from .types import CallbackClass, MatchingData
|
22
|
+
from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
|
23
|
+
from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
|
24
|
+
|
25
|
+
|
26
|
+
def _wrap_backend(func):
    """
    Wrap ``func`` so that a given backend is activated before it runs.

    The returned callable accepts two additional keyword-only arguments,
    ``backend_name`` and ``backend_args``. They are forwarded to
    ``change_backend`` of :py:mod:`tme.backends` before ``func`` is called
    with all remaining positional and keyword arguments. In this module it is
    used together with joblib's ``delayed`` so that parallel jobs run on the
    backend that is active in the calling process.
    """

    @wraps(func)
    def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
        # Imported inside the wrapper so the lookup happens wherever the
        # wrapper is executed rather than where it was created.
        from tme.backends import backend as active_backend

        active_backend.change_backend(backend_name, **backend_args)
        result = func(*args, **kwargs)
        return result

    return wrapper
|
35
|
+
|
36
|
+
|
37
|
+
def _setup_template_filter_apply_target_filter(
    matching_data: MatchingData,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    pad_template_filter: bool = True,
):
    """
    Compute the template filter and apply a composable target filter.

    Parameters
    ----------
    matching_data : :py:class:`tme.matching_data.MatchingData`
        Matching data. If its ``target_filter`` is a
        :py:class:`tme.filters.Compose`, the filtered target replaces
        ``matching_data._target``.
    fast_shape : tuple of int
        Real-space shape used for the Fourier transforms.
    fast_ft_shape : tuple of int
        Shape passed to ``_batch_shape`` to derive the complex (rfft) shapes.
    pad_template_filter : bool, optional
        If a composable template filter is given, whether it is evaluated on
        ``fast_shape`` (True) or on ``matching_data._output_template_shape``
        (False).

    Returns
    -------
    BackendArray
        Template filter cast to ``be._float_dtype``. This is a single-element
        array of ones when no template filter is specified.
    """
    backend_arr = type(be.zeros((1), dtype=be._float_dtype))
    template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
    if isinstance(matching_data.template_filter, backend_arr):
        # For now assume user-supplied template_filter is correctly padded
        template_filter = matching_data.template_filter

    # NOTE(review): a target_filter given as a plain backend array is not
    # applied by this function; confirm whether it is handled elsewhere.
    filter_template = isinstance(matching_data.template_filter, Compose)
    filter_target = isinstance(matching_data.target_filter, Compose)

    # Without composable filters there is nothing to compute. The previous
    # guard compared these booleans against None and therefore never
    # triggered, which caused a needless Fourier transform of the target.
    if not (filter_template or filter_target):
        return be.astype(be.to_backend_array(template_filter), be._float_dtype)

    batch_mask = matching_data._batch_mask
    real_shape = matching_data._batch_shape(fast_shape, batch_mask, keepdims=False)
    cmpl_shape = matching_data._batch_shape(fast_ft_shape, batch_mask, keepdims=True)

    real_template_shape, cmpl_template_shape = real_shape, cmpl_shape
    cmpl_template_shape_full = matching_data._batch_shape(
        fast_ft_shape, matching_data._target_batch, keepdims=True
    )
    cmpl_target_shape_full = matching_data._batch_shape(
        fast_ft_shape, matching_data._template_batch, keepdims=True
    )
    if filter_template and not pad_template_filter:
        out_shape = matching_data._output_template_shape
        real_template_shape = matching_data._batch_shape(
            out_shape, batch_mask, keepdims=False
        )
        cmpl_template_shape = list(
            matching_data._batch_shape(out_shape, batch_mask, keepdims=True)
        )
        cmpl_template_shape_full = list(out_shape)
        # The last axis of a real-input FFT holds n // 2 + 1 frequencies
        cmpl_template_shape[-1] = cmpl_template_shape[-1] // 2 + 1
        cmpl_template_shape_full[-1] = cmpl_template_shape_full[-1] // 2 + 1

    # Setup composable filters
    target_temp = be.topleft_pad(matching_data.target, fast_shape)
    target_temp_ft = be.rfftn(target_temp)
    filter_kwargs = {
        "return_real_fourier": True,
        "shape_is_real_fourier": False,
        "data_rfft": target_temp_ft,
        "batch_dimension": matching_data._target_dim,
    }

    if filter_template:
        template_filter = matching_data.template_filter(
            shape=real_template_shape, **filter_kwargs
        )["data"]
        template_filter_size = int(be.size(template_filter))

        # Pick whichever candidate complex shape matches the filter's size
        if template_filter_size == prod(cmpl_template_shape_full):
            cmpl_template_shape = cmpl_template_shape_full
        elif template_filter_size == prod(cmpl_shape):
            cmpl_template_shape = cmpl_shape
        template_filter = be.reshape(template_filter, cmpl_template_shape)

    if filter_target:
        target_filter = matching_data.target_filter(
            shape=real_shape, weight_type=None, **filter_kwargs
        )["data"]
        if int(be.size(target_filter)) == prod(cmpl_target_shape_full):
            cmpl_shape = cmpl_target_shape_full

        target_filter = be.reshape(target_filter, cmpl_shape)
        target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)

        # Back to real space and crop to the original target shape
        target_temp = be.irfftn(target_temp_ft, s=target_temp.shape)
        matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)

    return be.astype(be.to_backend_array(template_filter), be._float_dtype)
|
118
|
+
|
119
|
+
|
120
|
+
def device_memory_handler(func: Callable):
    """
    Decorator providing a SharedMemory handler and a device context.

    The decorated function receives an additional ``shm_handler`` keyword
    argument: a running
    :py:class:`multiprocessing.managers.SharedMemoryManager` that is shut
    down when the call returns or raises. An optional ``gpu_index`` keyword
    argument (default 0) is removed from the keyword arguments and passed to
    ``be.set_device`` for the duration of the call.

    Exceptions raised by the decorated function propagate unchanged.
    """

    @wraps(func)
    def inner_function(*args, **kwargs):
        gpu_index = kwargs.pop("gpu_index", 0)
        # The context managers release resources on both success and failure,
        # so exceptions simply propagate. The previous implementation seeded
        # its error state from sys.exc_info(). When called from inside an
        # except block, it re-raised the caller's already-handled exception
        # even if func succeeded.
        with SharedMemoryManager() as smh:
            with be.set_device(gpu_index):
                return func(*args, shm_handler=smh, **kwargs)

    return inner_function
|
140
|
+
|
141
|
+
|
142
|
+
@device_memory_handler
def scan(
    matching_data: MatchingData,
    matching_setup: Callable,
    matching_score: Callable,
    n_jobs: int = 4,
    callback_class: CallbackClass = None,
    callback_class_args: Optional[Dict] = None,
    pad_fourier: bool = True,
    pad_template_filter: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    shm_handler=None,
    target_slice=None,
    template_slice=None,
) -> Optional[Tuple]:
    """
    Run template matching.

    .. warning:: ``matching_data`` might be altered or destroyed during computation.

    Parameters
    ----------
    matching_data : :py:class:`tme.matching_data.MatchingData`
        Template matching data.
    matching_setup : Callable
        Function pointer to setup function.
    matching_score : Callable
        Function pointer to scoring function.
    n_jobs : int, optional
        Number of parallel jobs. Default is 4.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. Default is None, i.e. no
        additional arguments.
    pad_fourier: bool, optional
        Whether to pad target and template to the full convolution shape.
    pad_template_filter: bool, optional
        Whether to pad potential template filters to the full convolution shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        Number of jobs a callback_class instance is shared between, 8 by default.
        Forced to 1 for analyzers that are not ``is_shareable``.
    shm_handler : type, optional
        Manager for shared memory objects. Supplied by the
        :py:meth:`device_memory_handler` decorator.
    target_slice : optional
        Slice of the target to match against, passed to
        ``MatchingData.subset_by_slice``. None by default.
    template_slice : optional
        Slice of the template to use, passed to
        ``MatchingData.subset_by_slice``. None by default.
    gpu_index : int, optional
        Keyword-only. Consumed by :py:meth:`device_memory_handler` and passed
        to ``be.set_device``. Defaults to 0.

    Returns
    -------
    Optional[Tuple]
        The merged results from ``callback_class.merge`` if callback_class is
        provided. Otherwise, the list of values returned by the individual
        ``matching_score`` jobs.

    Examples
    --------
    Schematically, :py:meth:`scan` is identical to :py:meth:`scan_subsets`,
    with the distinction that the objects contained in ``matching_data`` are not
    split and the search is only parallelized over angles.
    Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
    can be invoked like so

    >>> from tme.matching_exhaustive import scan
    >>> results = scan(
    ...    matching_data=matching_data,
    ...    matching_score=matching_score,
    ...    matching_setup=matching_setup,
    ...    callback_class=callback_class,
    ...    callback_class_args=callback_class_args,
    ... )

    """
    if callback_class_args is None:
        callback_class_args = {}

    matching_data = matching_data.subset_by_slice(
        target_slice=target_slice,
        template_slice=template_slice,
        target_pad=matching_data.target_padding(pad_target=pad_fourier),
    )

    matching_data.to_backend()
    template_shape = matching_data._batch_shape(
        matching_data.template.shape, matching_data._template_batch
    )
    conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=False)

    template_filter = _setup_template_filter_apply_target_filter(
        matching_data=matching_data,
        fast_shape=fwd,
        fast_ft_shape=inv,
        pad_template_filter=pad_template_filter,
    )

    default_callback_args = {
        "shape": fwd,
        "offset": matching_data._translation_offset,
        "fourier_shift": shift,
        "fast_shape": fwd,
        "targetshape": matching_data._output_shape,
        "templateshape": template_shape,
        "convolution_shape": conv,
        "thread_safe": n_jobs > 1,
        "convolution_mode": "valid" if pad_fourier else "same",
        "shm_handler": shm_handler,
        "only_unique_rotations": True,
        "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
        "n_rotations": matching_data.rotations.shape[0],
    }
    default_callback_args.update(callback_class_args)

    setup = matching_setup(
        matching_data=matching_data,
        template_filter=template_filter,
        fast_shape=fwd,
        fast_ft_shape=inv,
        shm_handler=shm_handler,
    )
    setup["interpolation_order"] = interpolation_order
    setup["template_filter"] = be.to_sharedarr(template_filter, shm_handler)

    matching_data._free_data()
    be.free_cache()

    # Some analyzers cannot be shared across processes
    if not getattr(callback_class, "is_shareable", False):
        jobs_per_callback_class = 1

    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
    callback_classes = [
        callback_class(**default_callback_args) if callback_class else None
        for _ in range(n_callback_classes)
    ]
    ret = Parallel(n_jobs=n_jobs)(
        delayed(_wrap_backend(matching_score))(
            backend_name=be._backend_name,
            backend_args=be._backend_args,
            rotations=rotation,
            callback=callback_classes[index % n_callback_classes],
            **setup,
        )
        for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
    )

    # TODO: Make sure peak callers are thread safe to begin with
    # For non-shareable analyzers, use the objects returned by the jobs
    # (presumably the worker-side callback instances) instead of the local ones.
    if not getattr(callback_class, "is_shareable", False):
        callback_classes = ret

    callbacks = [
        tuple(callback._postprocess(**default_callback_args))
        for callback in callback_classes
        if callback
    ]
    be.free_cache()

    if callback_class:
        ret = callback_class.merge(callbacks, **default_callback_args)
    return ret
|
294
|
+
|
295
|
+
|
296
|
+
def scan_subsets(
    matching_data: MatchingData,
    matching_score: Callable,
    matching_setup: Callable,
    callback_class: CallbackClass = None,
    callback_class_args: Optional[Dict] = None,
    job_schedule: Tuple[int] = (1, 1),
    target_splits: Optional[Dict] = None,
    template_splits: Optional[Dict] = None,
    pad_target_edges: bool = False,
    pad_template_filter: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    backend_name: str = None,
    backend_args: Optional[Dict] = None,
    verbose: bool = False,
    **kwargs,
) -> Optional[Tuple]:
    """
    Wrapper around :py:meth:`scan` that supports matching on splits
    of ``matching_data``.

    Parameters
    ----------
    matching_data : :py:class:`tme.matching_data.MatchingData`
        MatchingData instance containing relevant data.
    matching_score : Callable
        Function pointer to scoring function.
    matching_setup : Callable
        Function pointer to setup function.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. Default is None, i.e. no
        additional arguments.
    job_schedule : tuple of int, optional
        Job scheduling scheme, default is (1, 1). First value corresponds
        to the number of splits that are processed in parallel, the second
        to the number of angles evaluated in parallel on each split.
    target_splits : dict, optional
        Splits for target. Default is None, i.e. no splits.
        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
    template_splits : dict, optional
        Splits for template. Default is None, i.e. no splits.
        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
    pad_target_edges : bool, optional
        Pad the target boundaries to avoid edge effects.
    pad_template_filter: bool, optional
        Whether to pad potential template filters to the full convolution shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        How many jobs should be processed by a single callback_class instance,
        if one is provided.
    backend_name : str, optional
        Currently unused; jobs run on the backend that is active in the
        calling process.
    backend_args : dict, optional
        Currently unused, see ``backend_name``.
    verbose : bool, optional
        Indicate matching progress.
    **kwargs
        Additional keyword arguments are accepted and ignored.

    Returns
    -------
    Optional[Tuple]
        The merged results from callback_class if provided otherwise None.

    Examples
    --------
    All data relevant to template matching will be contained in ``matching_data``, which
    is a :py:class:`tme.matching_data.MatchingData` instance and can be created like so

    >>> import numpy as np
    >>> from tme.matching_data import MatchingData
    >>> from tme.matching_utils import get_rotation_matrices
    >>> target = np.random.rand(50,40,60)
    >>> template = target[15:25, 10:20, 30:40]
    >>> matching_data = MatchingData(target, template)
    >>> matching_data.rotations = get_rotation_matrices(
    ...    angular_sampling=60, dim=target.ndim
    ... )

    The template matching procedure is determined by ``matching_setup`` and
    ``matching_score``, which are unique to each score. In the following,
    we will be using the `FLCSphericalMask` score, which is composed of
    :py:meth:`tme.matching_scores.flcSphericalMask_setup` and
    :py:meth:`tme.matching_scores.corr_scoring`

    >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
    >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
    >>> matching_setup, matching_score = funcs

    Computed scores are flexibly analyzed by being passed through an analyzer. In the
    following, we will use :py:class:`tme.analyzer.MaxScoreOverRotations` to
    aggregate scores over rotations

    >>> from tme.analyzer import MaxScoreOverRotations
    >>> callback_class = MaxScoreOverRotations
    >>> callback_class_args = {"score_threshold" : 0}

    In case the entire template matching problem does not fit into memory, we can
    determine the splitting procedure. In this case, we halve the first axis of the target
    once. Splitting and ``job_schedule`` are typically computed using
    :py:meth:`tme.matching_utils.compute_parallelization_schedule`.

    >>> target_splits = {0 : 1}

    Finally, we can perform template matching. Note that the data
    contained in ``matching_data`` will be destroyed when running the following

    >>> from tme.matching_exhaustive import scan_subsets
    >>> results = scan_subsets(
    ...    matching_data=matching_data,
    ...    matching_score=matching_score,
    ...    matching_setup=matching_setup,
    ...    callback_class=callback_class,
    ...    callback_class_args=callback_class_args,
    ...    target_splits=target_splits,
    ... )

    The ``results`` tuple contains the output of the chosen analyzer.

    See Also
    --------
    :py:meth:`tme.matching_utils.compute_parallelization_schedule`
    """
    # Avoid shared mutable defaults
    callback_class_args = {} if callback_class_args is None else callback_class_args
    target_splits = {} if target_splits is None else target_splits
    template_splits = {} if template_splits is None else template_splits

    template_splits = split_shape(matching_data._template.shape, splits=template_splits)
    target_splits = split_shape(matching_data._target.shape, splits=target_splits)
    if (len(target_splits) > 1) and not pad_target_edges:
        warnings.warn(
            "Target splitting without padding target edges leads to unreliable "
            "similarity estimates around the split border."
        )
    splits = tuple(product(target_splits, template_splits))

    outer_jobs, inner_jobs = job_schedule
    if hasattr(be, "scan"):
        corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
        results = be.scan(
            matching_data=matching_data,
            splits=splits,
            n_jobs=outer_jobs,
            rotate_mask=matching_score != corr_scoring,
            callback_class=callback_class,
        )
    else:
        results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
            [
                delayed(_wrap_backend(scan))(
                    backend_name=be._backend_name,
                    backend_args=be._backend_args,
                    matching_data=matching_data,
                    matching_score=matching_score,
                    matching_setup=matching_setup,
                    n_jobs=inner_jobs,
                    callback_class=callback_class,
                    callback_class_args=callback_class_args,
                    interpolation_order=interpolation_order,
                    pad_fourier=pad_target_edges,
                    gpu_index=index % outer_jobs,
                    pad_template_filter=pad_template_filter,
                    target_slice=target_split,
                    template_slice=template_split,
                )
                for index, (target_split, template_split) in enumerate(splits)
            ]
        )

    matching_data._free_data()
    if callback_class is not None:
        return callback_class.merge(results, **callback_class_args)
    return None
|
462
|
+
|
463
|
+
|
464
|
+
def register_matching_exhaustive(
    matching: str,
    matching_setup: Callable,
    matching_scoring: Callable,
    memory_class: MatchingMemoryUsage,
) -> None:
    """
    Registers a new matching scheme.

    Parameters
    ----------
    matching : str
        Name of the matching method.
    matching_setup : Callable
        Corresponding setup function.
    matching_scoring : Callable
        Corresponding scoring function.
    memory_class : MatchingMemoryUsage
        Child of :py:class:`tme.memory.MatchingMemoryUsage`.

    Raises
    ------
    ValueError
        If a function with the name ``matching`` already exists in the registry, or
        if ``memory_class`` is no child of :py:class:`tme.memory.MatchingMemoryUsage`.
    """

    if matching in MATCHING_EXHAUSTIVE_REGISTER:
        raise ValueError(f"A method with name '{matching}' is already registered.")
    # issubclass raises TypeError for non-class arguments (e.g. an instance).
    # Check for a class first so the documented ValueError is raised instead.
    if not (
        isinstance(memory_class, type) and issubclass(memory_class, MatchingMemoryUsage)
    ):
        raise ValueError(f"{memory_class} is not a subclass of {MatchingMemoryUsage}.")

    MATCHING_EXHAUSTIVE_REGISTER[matching] = (matching_setup, matching_scoring)
    MATCHING_MEMORY_REGISTRY[matching] = memory_class
|