pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,497 @@
1
+ """ Implements cross-correlation based template matching using different metrics.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import sys
9
+ import warnings
10
+ from math import prod
11
+ from functools import wraps
12
+ from itertools import product
13
+ from typing import Callable, Tuple, Dict, Optional
14
+
15
+ from joblib import Parallel, delayed
16
+ from multiprocessing.managers import SharedMemoryManager
17
+
18
+ from .filters import Compose
19
+ from .backends import backend as be
20
+ from .matching_utils import split_shape
21
+ from .types import CallbackClass, MatchingData
22
+ from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
23
+ from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
24
+
25
+
26
+ def _wrap_backend(func):
27
+ @wraps(func)
28
+ def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
29
+ from tme.backends import backend as be
30
+
31
+ be.change_backend(backend_name, **backend_args)
32
+ return func(*args, **kwargs)
33
+
34
+ return wrapper
35
+
36
+
37
def _setup_template_filter_apply_target_filter(
    matching_data: MatchingData,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    pad_template_filter: bool = True,
):
    """
    Build the template filter and apply a composable target filter in place.

    Parameters
    ----------
    matching_data : MatchingData
        Object providing ``target``, ``template_filter`` and ``target_filter``.
        Its ``_target`` attribute is overwritten when ``target_filter`` is a
        :py:class:`Compose` instance.
    fast_shape : tuple of int
        Shape the target is padded to before the forward transform.
    fast_ft_shape : tuple of int
        Shape used to derive the reshape targets of the computed filters.
    pad_template_filter : bool, optional
        If False and the template filter is a :py:class:`Compose` instance,
        the filter is computed for ``matching_data._output_template_shape``
        instead of ``fast_shape``. Default is True.

    Returns
    -------
    BackendArray
        Template filter converted to ``be._float_dtype``. This is an array
        holding a single one if ``matching_data.template_filter`` is neither
        a backend array nor a :py:class:`Compose` instance.
    """
    array_type = type(be.zeros((1), dtype=be._float_dtype))

    # Neutral default, replaced by a user-supplied backend array if present.
    template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
    if isinstance(matching_data.template_filter, array_type):
        template_filter = matching_data.template_filter

    target_filter = None
    if isinstance(matching_data.target_filter, array_type):
        target_filter = matching_data.target_filter

    filter_template = isinstance(matching_data.template_filter, Compose)
    filter_target = isinstance(matching_data.target_filter, Compose)

    # For now assume user-supplied template_filter is correctly padded
    # NOTE(review): filter_target is the bool returned by isinstance and is
    # therefore never None, so this early return cannot trigger. The intended
    # condition is unclear - confirm before changing it.
    if filter_target is None and target_filter is None:
        return template_filter

    batch_mask = matching_data._batch_mask
    real_shape = matching_data._batch_shape(fast_shape, batch_mask, keepdims=False)
    cmpl_shape = matching_data._batch_shape(fast_ft_shape, batch_mask, keepdims=True)
    cmpl_template_shape_full = matching_data._batch_shape(
        fast_ft_shape, matching_data._target_batch, keepdims=True
    )
    cmpl_target_shape_full = matching_data._batch_shape(
        fast_ft_shape, matching_data._template_batch, keepdims=True
    )

    if filter_template and not pad_template_filter:
        # Unpadded case: derive all template shapes from the output template
        # shape; the last axis of length n is reduced to n // 2 + 1.
        out_shape = matching_data._output_template_shape
        real_template_shape = matching_data._batch_shape(
            out_shape, batch_mask, keepdims=False
        )
        cmpl_template_shape = list(
            matching_data._batch_shape(out_shape, batch_mask, keepdims=True)
        )
        cmpl_template_shape_full = list(out_shape)
        cmpl_template_shape[-1] = cmpl_template_shape[-1] // 2 + 1
        cmpl_template_shape_full[-1] = cmpl_template_shape_full[-1] // 2 + 1
    else:
        real_template_shape, cmpl_template_shape = real_shape, cmpl_shape

    # Setup composable filters
    target_temp = be.topleft_pad(matching_data.target, fast_shape)
    target_temp_ft = be.rfftn(target_temp)
    filter_kwargs = {
        "return_real_fourier": True,
        "shape_is_real_fourier": False,
        "data_rfft": target_temp_ft,
        "batch_dimension": matching_data._target_dim,
    }

    if filter_template:
        template_filter = matching_data.template_filter(
            shape=real_template_shape, **filter_kwargs
        )["data"]

        # Choose the reshape target whose element count matches the filter.
        n_elements = int(be.size(template_filter))
        if n_elements == prod(cmpl_template_shape_full):
            cmpl_template_shape = cmpl_template_shape_full
        elif n_elements == prod(cmpl_shape):
            cmpl_template_shape = cmpl_shape
        template_filter = be.reshape(template_filter, cmpl_template_shape)

    if filter_target:
        target_filter = matching_data.target_filter(
            shape=real_shape, weight_type=None, **filter_kwargs
        )["data"]
        if int(be.size(target_filter)) == prod(cmpl_target_shape_full):
            cmpl_shape = cmpl_target_shape_full
        target_filter = be.reshape(target_filter, cmpl_shape)

        # Filter in Fourier space, reusing the transform buffer as output.
        target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)

        # NOTE(review): the write-back below is assumed to belong to the
        # target-filter branch - confirm against upstream.
        target_temp = be.irfftn(target_temp_ft, s=target_temp.shape)
        matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)

    return be.astype(be.to_backend_array(template_filter), be._float_dtype)
118
+
119
+
120
def device_memory_handler(func: Callable):
    """
    Decorator providing a shared memory manager and device context to ``func``.

    The decorated function is executed inside a
    :py:class:`multiprocessing.managers.SharedMemoryManager` context and
    inside ``be.set_device(gpu_index)``. The manager is passed to ``func`` as
    the ``shm_handler`` keyword argument.

    Parameters
    ----------
    func : Callable
        Function accepting a ``shm_handler`` keyword argument.

    Returns
    -------
    Callable
        Wrapper around ``func`` that additionally accepts an optional
        ``gpu_index`` keyword argument, 0 by default. ``gpu_index`` is
        consumed by the wrapper and not forwarded to ``func``.

    Notes
    -----
    Exceptions raised by ``func`` propagate unchanged once both contexts have
    been exited. The wrapper deliberately does not inspect ``sys.exc_info()``
    up front: doing so would re-raise an exception the *caller* is currently
    handling, even if ``func`` itself completed successfully.
    """

    @wraps(func)
    def inner_function(*args, **kwargs):
        gpu_index = kwargs.pop("gpu_index", 0)
        with SharedMemoryManager() as smh:
            with be.set_device(gpu_index):
                return func(shm_handler=smh, *args, **kwargs)

    return inner_function
140
+
141
+
142
@device_memory_handler
def scan(
    matching_data: MatchingData,
    matching_setup: Callable,
    matching_score: Callable,
    n_jobs: int = 4,
    callback_class: CallbackClass = None,
    callback_class_args: Dict = {},
    pad_fourier: bool = True,
    pad_template_filter: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    shm_handler=None,
    target_slice=None,
    template_slice=None,
) -> Optional[Tuple]:
    """
    Run template matching.

    .. warning:: ``matching_data`` might be altered or destroyed during computation.

    Parameters
    ----------
    matching_data : :py:class:`tme.matching_data.MatchingData`
        Template matching data.
    matching_setup : Callable
        Function pointer to setup function.
    matching_score : Callable
        Function pointer to scoring function.
    n_jobs : int, optional
        Number of parallel jobs. Default is 4.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. Default is an empty dictionary.
        Entries override the default callback arguments assembled here.
    pad_fourier: bool, optional
        Whether to pad target and template to the full convolution shape.
    pad_template_filter: bool, optional
        Whether to pad potential template filters to the full convolution shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        Number of jobs a callback_class instance is shared between, 8 by default.
        Forced to 1 if ``callback_class`` has no truthy ``is_shareable`` attribute.
    shm_handler : type, optional
        Manager for shared memory objects, None by default. Supplied by the
        :py:meth:`device_memory_handler` decorator.
    target_slice : optional
        Forwarded to ``matching_data.subset_by_slice``, None by default.
    template_slice : optional
        Forwarded to ``matching_data.subset_by_slice``, None by default.

    Returns
    -------
    Optional[Tuple]
        The merged results from callback_class if provided, otherwise the
        values returned by the parallel ``matching_score`` calls.

    Notes
    -----
    The decorator additionally accepts a ``gpu_index`` keyword argument that
    selects the device and is not passed on to this function.

    Examples
    --------
    Schematically, :py:meth:`scan` is identical to :py:meth:`scan_subsets`,
    with the distinction that the objects contained in ``matching_data`` are not
    split and the search is only parallelized over angles.
    Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
    can be invoked like so

    >>> from tme.matching_exhaustive import scan
    >>> results = scan(
    ...    matching_data=matching_data,
    ...    matching_score=matching_score,
    ...    matching_setup=matching_setup,
    ...    callback_class=callback_class,
    ...    callback_class_args=callback_class_args,
    ... )

    """
    target_pad = matching_data.target_padding(pad_target=pad_fourier)
    matching_data = matching_data.subset_by_slice(
        target_slice=target_slice,
        template_slice=template_slice,
        target_pad=target_pad,
    )
    matching_data.to_backend()

    template_shape = matching_data._batch_shape(
        matching_data.template.shape, matching_data._template_batch
    )
    conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=False)

    template_filter = _setup_template_filter_apply_target_filter(
        matching_data=matching_data,
        fast_shape=fwd,
        fast_ft_shape=inv,
        pad_template_filter=pad_template_filter,
    )

    # Assembled before the data is freed below; user-supplied entries win.
    default_callback_args = dict(
        shape=fwd,
        offset=matching_data._translation_offset,
        fourier_shift=shift,
        fast_shape=fwd,
        targetshape=matching_data._output_shape,
        templateshape=template_shape,
        convolution_shape=conv,
        thread_safe=n_jobs > 1,
        convolution_mode="valid" if pad_fourier else "same",
        shm_handler=shm_handler,
        only_unique_rotations=True,
        aggregate_axis=matching_data._batch_axis(matching_data._batch_mask),
        n_rotations=matching_data.rotations.shape[0],
    )
    default_callback_args.update(callback_class_args)

    setup = matching_setup(
        matching_data=matching_data,
        template_filter=template_filter,
        fast_shape=fwd,
        fast_ft_shape=inv,
        shm_handler=shm_handler,
    )
    setup["interpolation_order"] = interpolation_order
    setup["template_filter"] = be.to_sharedarr(template_filter, shm_handler)

    matching_data._free_data()
    be.free_cache()

    # Some analyzers cannot be shared across processes
    shareable = getattr(callback_class, "is_shareable", False)
    if not shareable:
        jobs_per_callback_class = 1

    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
    if callback_class:
        callback_classes = [
            callback_class(**default_callback_args)
            for _ in range(n_callback_classes)
        ]
    else:
        callback_classes = [None] * n_callback_classes

    score_job = delayed(_wrap_backend(matching_score))
    rotation_sets = matching_data._split_rotations_on_jobs(n_jobs)
    ret = Parallel(n_jobs=n_jobs)(
        score_job(
            backend_name=be._backend_name,
            backend_args=be._backend_args,
            rotations=rotation,
            callback=callback_classes[index % n_callback_classes],
            **setup,
        )
        for index, rotation in enumerate(rotation_sets)
    )

    # TODO: Make sure peak callers are thread safe to begin with
    # Non-shareable analyzers are taken from the job return values instead of
    # the instances created above.
    if not shareable:
        callback_classes = ret

    callbacks = []
    for callback in callback_classes:
        if callback:
            callbacks.append(tuple(callback._postprocess(**default_callback_args)))
    be.free_cache()

    if callback_class:
        ret = callback_class.merge(callbacks, **default_callback_args)
    return ret
294
+
295
+
296
def scan_subsets(
    matching_data: MatchingData,
    matching_score: Callable,
    matching_setup: Callable,
    callback_class: CallbackClass = None,
    callback_class_args: Dict = {},
    job_schedule: Tuple[int] = (1, 1),
    target_splits: Dict = {},
    template_splits: Dict = {},
    pad_target_edges: bool = False,
    pad_template_filter: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    backend_name: str = None,
    backend_args: Dict = {},
    verbose: bool = False,
    **kwargs,
) -> Optional[Tuple]:
    """
    Wrapper around :py:meth:`scan` that supports matching on splits
    of ``matching_data``.

    Parameters
    ----------
    matching_data : :py:class:`tme.matching_data.MatchingData`
        MatchingData instance containing relevant data.
    matching_setup : type
        Function pointer to setup function.
    matching_score : type
        Function pointer to scoring function.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. Default is an empty dictionary.
    job_schedule : tuple of int, optional
        Job scheduling scheme, default is (1, 1). First value corresponds
        to the number of splits that are processed in parallel, the second
        to the number of angles evaluated in parallel on each split.
    target_splits : dict, optional
        Splits for target. Default is an empty dictionary, i.e. no splits.
        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
    template_splits : dict, optional
        Splits for template. Default is an empty dictionary, i.e. no splits.
        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
    pad_target_edges : bool, optional
        Pad the target boundaries to avoid edge effects.
    pad_template_filter: bool, optional
        Whether to pad potential template filters to the full convolution shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        How many jobs should be processed by a single callback_class instance,
        if one is provided. Currently not forwarded to :py:meth:`scan`.
    backend_name : str, optional
        Not used by this function.
    backend_args : dict, optional
        Not used by this function.
    verbose : bool, optional
        Indicate matching progress.
    **kwargs : dict, optional
        Not used by this function.

    Returns
    -------
    Optional[Tuple]
        The merged results from callback_class if provided otherwise None.

    Examples
    --------
    All data relevant to template matching will be contained in ``matching_data``, which
    is a :py:class:`tme.matching_data.MatchingData` instance and can be created like so

    >>> import numpy as np
    >>> from tme.matching_data import MatchingData
    >>> from tme.matching_utils import get_rotation_matrices
    >>> target = np.random.rand(50,40,60)
    >>> template = target[15:25, 10:20, 30:40]
    >>> matching_data = MatchingData(target, template)
    >>> matching_data.rotations = get_rotation_matrices(
    ...    angular_sampling=60, dim=target.ndim
    ... )

    The template matching procedure is determined by ``matching_setup`` and
    ``matching_score``, which are unique to each score. In the following,
    we will be using the `FLCSphericalMask` score, which is composed of
    :py:meth:`tme.matching_scores.flcSphericalMask_setup` and
    :py:meth:`tme.matching_scores.corr_scoring`

    >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
    >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
    >>> matching_setup, matching_score = funcs

    Computed scores are flexibly analyzed by being passed through an analyzer. In the
    following, we will use :py:class:`tme.analyzer.MaxScoreOverRotations` to
    aggregate scores over rotations

    >>> from tme.analyzer import MaxScoreOverRotations
    >>> callback_class = MaxScoreOverRotations
    >>> callback_class_args = {"score_threshold" : 0}

    In case the entire template matching problem does not fit into memory, we can
    determine the splitting procedure. In this case, we halve the first axis of the
    target once. Splitting and ``job_schedule`` is typically computed using
    :py:meth:`tme.matching_utils.compute_parallelization_schedule`.

    >>> target_splits = {0 : 1}

    Finally, we can perform template matching. Note that the data
    contained in ``matching_data`` will be destroyed when running the following

    >>> from tme.matching_exhaustive import scan_subsets
    >>> results = scan_subsets(
    ...    matching_data=matching_data,
    ...    matching_score=matching_score,
    ...    matching_setup=matching_setup,
    ...    callback_class=callback_class,
    ...    callback_class_args=callback_class_args,
    ...    target_splits=target_splits,
    ... )

    The ``results`` tuple contains the output of the chosen analyzer.

    See Also
    --------
    :py:meth:`tme.matching_utils.compute_parallelization_schedule`
    """
    template_chunks = split_shape(
        matching_data._template.shape, splits=template_splits
    )
    target_chunks = split_shape(matching_data._target.shape, splits=target_splits)

    if not pad_target_edges and len(target_chunks) > 1:
        warnings.warn(
            "Target splitting without padding target edges leads to unreliable "
            "similarity estimates around the split border."
        )

    # Every target chunk is matched against every template chunk.
    splits = tuple(product(target_chunks, template_chunks))
    outer_jobs, inner_jobs = job_schedule

    if hasattr(be, "scan"):
        # The active backend provides its own scan implementation.
        corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
        results = be.scan(
            matching_data=matching_data,
            splits=splits,
            n_jobs=outer_jobs,
            rotate_mask=matching_score != corr_scoring,
            callback_class=callback_class,
        )
    else:
        scan_job = delayed(_wrap_backend(scan))
        jobs = []
        for index, (target_split, template_split) in enumerate(splits):
            jobs.append(
                scan_job(
                    backend_name=be._backend_name,
                    backend_args=be._backend_args,
                    matching_data=matching_data,
                    matching_score=matching_score,
                    matching_setup=matching_setup,
                    n_jobs=inner_jobs,
                    callback_class=callback_class,
                    callback_class_args=callback_class_args,
                    interpolation_order=interpolation_order,
                    pad_fourier=pad_target_edges,
                    # Device index cycles over the number of outer jobs.
                    gpu_index=index % outer_jobs,
                    pad_template_filter=pad_template_filter,
                    target_slice=target_split,
                    template_slice=template_split,
                )
            )
        results = Parallel(n_jobs=outer_jobs, verbose=verbose)(jobs)

    matching_data._free_data()
    if callback_class is None:
        return None
    return callback_class.merge(results, **callback_class_args)
462
+
463
+
464
def register_matching_exhaustive(
    matching: str,
    matching_setup: Callable,
    matching_scoring: Callable,
    memory_class: MatchingMemoryUsage,
) -> None:
    """
    Registers a new matching scheme.

    Parameters
    ----------
    matching : str
        Name of the matching method.
    matching_setup : Callable
        Corresponding setup function.
    matching_scoring : Callable
        Corresponding scoring function.
    memory_class : MatchingMemoryUsage
        Child of :py:class:`tme.memory.MatchingMemoryUsage`.

    Raises
    ------
    ValueError
        If a function with the name ``matching`` already exists in the registry, or
        if ``memory_class`` is no child of :py:class:`tme.memory.MatchingMemoryUsage`.
    """
    if matching in MATCHING_EXHAUSTIVE_REGISTER:
        raise ValueError(f"A method with name '{matching}' is already registered.")

    # issubclass raises TypeError when its first argument is not a class
    # (e.g. an instance was passed), so check that first to honor the
    # documented ValueError.
    is_memory_class = isinstance(memory_class, type) and issubclass(
        memory_class, MatchingMemoryUsage
    )
    if not is_memory_class:
        raise ValueError(f"{memory_class} is not a subclass of {MatchingMemoryUsage}.")

    # Both registries are only updated once all validation has passed.
    MATCHING_EXHAUSTIVE_REGISTER[matching] = (matching_setup, matching_scoring)
    MATCHING_MEMORY_REGISTRY[matching] = memory_class