pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -4,1300 +4,57 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- import os
8
7
  import sys
9
8
  import warnings
10
- from itertools import product
11
- from typing import Callable, Tuple, Dict
9
+ import traceback
12
10
  from functools import wraps
11
+ from itertools import product
12
+ from typing import Callable, Tuple, Dict, Optional
13
+
13
14
  from joblib import Parallel, delayed
14
15
  from multiprocessing.managers import SharedMemoryManager
15
16
 
16
- import numpy as np
17
- from scipy.ndimage import laplace
18
-
19
- from .analyzer import MaxScoreOverRotations
20
- from .matching_utils import (
21
- handle_traceback,
22
- split_numpy_array_slices,
23
- conditional_execute,
24
- )
25
- from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
17
+ from .backends import backend as be
26
18
  from .preprocessing import Compose
27
- from .matching_data import MatchingData
28
- from .backends import backend
29
- from .types import NDArray, CallbackClass
30
-
31
- os.environ["MKL_NUM_THREADS"] = "1"
32
- os.environ["OMP_NUM_THREADS"] = "1"
33
- os.environ["PYFFTW_NUM_THREADS"] = "1"
34
- os.environ["OPENBLAS_NUM_THREADS"] = "1"
35
-
36
-
37
- def _run_inner(backend_name, backend_args, **kwargs):
38
- from tme.backends import backend
19
+ from .matching_utils import split_shape
20
+ from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
21
+ from .types import CallbackClass, MatchingData
22
+ from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
39
23
 
40
- backend.change_backend(backend_name, **backend_args)
41
- return scan(**kwargs)
42
24
 
43
-
44
- def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
25
+ def _handle_traceback(last_type, last_value, last_traceback):
45
26
  """
46
- Standardizes the values in in template by subtracting the mean and dividing by the
47
- standard deviation based on the elements in mask. Subsequently, the template is
48
- multiplied by the mask.
27
+ Handle sys.exc_info().
49
28
 
50
29
  Parameters
51
30
  ----------
52
- template : NDArray
53
- The data array to be normalized. This array is modified in-place.
54
- mask : NDArray
55
- A boolean array of the same shape as `template`. True values indicate the positions in `template`
56
- to consider for normalization.
57
- mask_intensity : float
58
- Mask intensity used to compute expectations.
59
-
60
- References
61
- ----------
62
- .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
63
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
64
-
65
- Returns
66
- -------
67
- None
68
- This function modifies `template` in-place and does not return any value.
69
- """
70
- masked_mean = backend.sum(backend.multiply(template, mask))
71
- masked_mean = backend.divide(masked_mean, mask_intensity)
72
- masked_std = backend.sum(backend.multiply(backend.square(template), mask))
73
- masked_std = backend.subtract(
74
- masked_std / mask_intensity, backend.square(masked_mean)
75
- )
76
- masked_std = backend.sqrt(backend.maximum(masked_std, 0))
77
-
78
- backend.subtract(template, masked_mean, out=template)
79
- backend.divide(template, masked_std, out=template)
80
- backend.multiply(template, mask, out=template)
81
- return None
82
-
83
-
84
- def _normalize_under_mask_overflow_safe(
85
- template: NDArray, mask: NDArray, mask_intensity
86
- ) -> None:
87
- _template = backend.astype(template, backend._overflow_safe_dtype)
88
- _mask = backend.astype(mask, backend._overflow_safe_dtype)
89
- normalize_under_mask(template=_template, mask=_mask, mask_intensity=mask_intensity)
90
- template[:] = backend.astype(_template, template.dtype)
91
- return None
92
-
93
-
94
- def apply_filter(ft_template, template_filter):
95
- # This is an approximation to applying the mask, irfftn, normalize, rfftn
96
- std_before = backend.std(ft_template)
97
- backend.multiply(ft_template, template_filter, out=ft_template)
98
- backend.multiply(
99
- ft_template, std_before / backend.std(ft_template), out=ft_template
100
- )
101
-
102
-
103
- def cc_setup(
104
- rfftn: Callable,
105
- irfftn: Callable,
106
- template: NDArray,
107
- target: NDArray,
108
- fast_shape: Tuple[int],
109
- fast_ft_shape: Tuple[int],
110
- shared_memory_handler: Callable,
111
- callback_class: Callable,
112
- callback_class_args: Dict,
113
- **kwargs,
114
- ) -> Dict:
115
- """
116
- Setup to compute the cross-correlation between a target f and template g
117
- defined as:
118
-
119
- .. math::
120
-
121
- \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
122
-
123
-
124
- See Also
125
- --------
126
- :py:meth:`corr_scoring`
127
- :py:class:`tme.matching_optimization.CrossCorrelation`
128
- """
129
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
130
- target_pad = backend.topleft_pad(target, fast_shape)
131
- target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
132
-
133
- rfftn(target_pad, target_pad_ft)
134
-
135
- target_ft_out = backend.arr_to_sharedarr(
136
- arr=target_pad_ft, shared_memory_handler=shared_memory_handler
137
- )
138
-
139
- template_out = backend.arr_to_sharedarr(
140
- arr=template, shared_memory_handler=shared_memory_handler
141
- )
142
- inv_denominator_buffer = backend.arr_to_sharedarr(
143
- arr=backend.preallocate_array(1, real_dtype) + 1,
144
- shared_memory_handler=shared_memory_handler,
145
- )
146
- numerator_buffer = backend.arr_to_sharedarr(
147
- arr=backend.preallocate_array(1, real_dtype),
148
- shared_memory_handler=shared_memory_handler,
149
- )
150
-
151
- target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
152
- template_tuple = (template_out, template.shape, template.dtype)
153
-
154
- inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
155
- numerator_tuple = (numerator_buffer, (1,), real_dtype)
156
-
157
- ret = {
158
- "template": template_tuple,
159
- "ft_target": target_ft_tuple,
160
- "inv_denominator": inv_denominator_tuple,
161
- "numerator": numerator_tuple,
162
- "targetshape": target.shape,
163
- "templateshape": template.shape,
164
- "fast_shape": fast_shape,
165
- "fast_ft_shape": fast_ft_shape,
166
- "callback_class": callback_class,
167
- "callback_class_args": callback_class_args,
168
- "use_memmap": kwargs.get("use_memmap", False),
169
- "temp_dir": kwargs.get("temp_dir", None),
170
- }
171
-
172
- return ret
173
-
174
-
175
- def lcc_setup(**kwargs) -> Dict:
176
- """
177
- Setup to compute the cross-correlation between a laplace transformed target f
178
- and laplace transformed template g defined as:
179
-
180
- .. math::
181
-
182
- \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)
183
-
184
-
185
- See Also
186
- --------
187
- :py:meth:`corr_scoring`
188
- :py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
189
- """
190
- kwargs["target"] = laplace(kwargs["target"], mode="wrap")
191
- kwargs["template"] = laplace(kwargs["template"], mode="wrap")
192
- return cc_setup(**kwargs)
193
-
194
-
195
- def corr_setup(
196
- rfftn: Callable,
197
- irfftn: Callable,
198
- template: NDArray,
199
- template_mask: NDArray,
200
- target: NDArray,
201
- fast_shape: Tuple[int],
202
- fast_ft_shape: Tuple[int],
203
- shared_memory_handler: Callable,
204
- callback_class: Callable,
205
- callback_class_args: Dict,
206
- **kwargs,
207
- ) -> Dict:
208
- """
209
- Setup to compute a normalized cross-correlation between a target f and a template g.
210
-
211
- .. math::
212
-
213
- \\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
214
- {(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}}
215
-
216
- Where:
217
-
218
- .. math::
219
-
220
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
221
-
222
- and m is a mask with the same dimension as the template filled with ones.
223
-
224
- References
225
- ----------
226
- .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
227
- and Magic.
228
-
229
- See Also
230
- --------
231
- :py:meth:`corr_scoring`
232
- :py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
233
- """
234
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
235
- target_pad = backend.topleft_pad(target, fast_shape)
236
-
237
- # The exact composition of the denominator is debatable
238
- # scikit-image match_template multiplies the running sum of the target
239
- # with a scaling factor derived from the template. This is probably appropriate
240
- # in pattern matching situations where the template exists in the target
241
- window_template = backend.topleft_pad(template_mask, fast_shape)
242
- ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
243
- rfftn(window_template, ft_window_template)
244
- window_template = None
245
-
246
- # Target and squared target window sums
247
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
248
- ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
249
- denominator = backend.preallocate_array(fast_shape, real_dtype)
250
- target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
251
- rfftn(target_pad, ft_target)
252
-
253
- rfftn(backend.square(target_pad), ft_target2)
254
- backend.multiply(ft_target2, ft_window_template, out=ft_target2)
255
- irfftn(ft_target2, denominator)
256
-
257
- backend.multiply(ft_target, ft_window_template, out=ft_window_template)
258
- irfftn(ft_window_template, target_window_sum)
259
-
260
- target_pad, ft_target2, ft_window_template = None, None, None
261
-
262
- # Normalizing constants
263
- n_observations = backend.sum(template_mask)
264
- template_mean = backend.sum(backend.multiply(template, template_mask))
265
- template_mean = backend.divide(template_mean, n_observations)
266
- template_ssd = backend.sum(
267
- backend.square(
268
- backend.multiply(backend.subtract(template, template_mean), template_mask)
269
- )
270
- )
271
- template_volume = np.prod(tuple(int(x) for x in template.shape))
272
- backend.multiply(template, template_mask, out=template)
273
-
274
- # Final numerator is score - numerator
275
- numerator = backend.multiply(target_window_sum, template_mean)
276
-
277
- # Compute denominator
278
- backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
279
- backend.divide(target_window_sum, template_volume, out=target_window_sum)
280
-
281
- backend.subtract(denominator, target_window_sum, out=denominator)
282
- backend.multiply(denominator, template_ssd, out=denominator)
283
- backend.maximum(denominator, 0, out=denominator)
284
- backend.sqrt(denominator, out=denominator)
285
- target_window_sum = None
286
-
287
- # Invert denominator to compute final score as product
288
- denominator_mask = denominator > backend.eps(denominator.dtype)
289
- inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
290
- inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]
291
-
292
- # Convert arrays used in subsequent fitting to SharedMemory objects
293
- template_buffer = backend.arr_to_sharedarr(
294
- arr=template, shared_memory_handler=shared_memory_handler
295
- )
296
- target_ft_buffer = backend.arr_to_sharedarr(
297
- arr=ft_target, shared_memory_handler=shared_memory_handler
298
- )
299
- inv_denominator_buffer = backend.arr_to_sharedarr(
300
- arr=inv_denominator, shared_memory_handler=shared_memory_handler
301
- )
302
- numerator_buffer = backend.arr_to_sharedarr(
303
- arr=numerator, shared_memory_handler=shared_memory_handler
304
- )
305
-
306
- template_tuple = (template_buffer, template.shape, template.dtype)
307
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
308
-
309
- inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
310
- numerator_tuple = (numerator_buffer, fast_shape, real_dtype)
311
-
312
- ft_target, inv_denominator, numerator = None, None, None
313
-
314
- ret = {
315
- "template": template_tuple,
316
- "ft_target": target_ft_tuple,
317
- "inv_denominator": inv_denominator_tuple,
318
- "numerator": numerator_tuple,
319
- "targetshape": target.shape,
320
- "templateshape": template.shape,
321
- "fast_shape": fast_shape,
322
- "fast_ft_shape": fast_ft_shape,
323
- "callback_class": callback_class,
324
- "callback_class_args": callback_class_args,
325
- "template_mean": kwargs.get("template_mean", template_mean),
326
- }
327
-
328
- return ret
329
-
330
-
331
- def cam_setup(**kwargs):
332
- """
333
- Setup to compute a normalized cross-correlation between a target f and a template g
334
- over their means. In practice this can be expressed like the cross-correlation
335
- CORR defined in :py:meth:`corr_scoring`, so that:
336
-
337
- Notes
338
- -----
339
-
340
- .. math::
341
-
342
- \\text{CORR}(f-\\overline{f}, g - \\overline{g})
343
-
344
- Where
345
-
346
- .. math::
347
-
348
- \\overline{f}, \\overline{g}
349
-
350
- are the mean of f and g respectively.
351
-
352
- References
353
- ----------
354
- .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
355
- and Magic.
356
-
357
- See Also
358
- --------
359
- :py:meth:`corr_scoring`.
360
- :py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
361
- """
362
- kwargs["template"] = kwargs["template"] - kwargs["template"].mean()
363
- return corr_setup(**kwargs)
364
-
365
-
366
- def flc_setup(
367
- rfftn: Callable,
368
- irfftn: Callable,
369
- template: NDArray,
370
- template_mask: NDArray,
371
- target: NDArray,
372
- fast_shape: Tuple[int],
373
- fast_ft_shape: Tuple[int],
374
- shared_memory_handler: Callable,
375
- callback_class: Callable,
376
- callback_class_args: Dict,
377
- **kwargs,
378
- ) -> Dict:
379
- """
380
- Setup to compute a normalized cross-correlation score of a target f a template g
381
- and a mask m:
382
-
383
- .. math::
384
-
385
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
386
- {N_m * \\sqrt{
387
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
388
- }
389
-
390
- Where:
391
-
392
- .. math::
393
-
394
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
31
+ last_type : type
32
+ The type of the last exception.
33
+ last_value :
34
+ The value of the last exception.
35
+ last_traceback : traceback
36
+ Traceback call stack at the point where the Exception occured.
395
37
 
396
- and Nm is the number of voxels within the template mask m.
397
-
398
- References
399
- ----------
400
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
401
- Microsc. Microanal. 26, 2516 (2020)
402
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
403
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
404
-
405
- See Also
406
- --------
407
- :py:meth:`flc_scoring`
408
- """
409
- target_pad = backend.topleft_pad(target, fast_shape)
410
-
411
- # Target and squared target window sums
412
- ft_target = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
413
- ft_target2 = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
414
- rfftn(target_pad, ft_target)
415
- backend.square(target_pad, out=target_pad)
416
- rfftn(target_pad, ft_target2)
417
-
418
- # Convert arrays used in subsequent fitting to SharedMemory objects
419
- ft_target = backend.arr_to_sharedarr(
420
- arr=ft_target, shared_memory_handler=shared_memory_handler
421
- )
422
- ft_target2 = backend.arr_to_sharedarr(
423
- arr=ft_target2, shared_memory_handler=shared_memory_handler
424
- )
425
-
426
- normalize_under_mask(
427
- template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
428
- )
429
-
430
- template_buffer = backend.arr_to_sharedarr(
431
- arr=template, shared_memory_handler=shared_memory_handler
432
- )
433
- template_mask_buffer = backend.arr_to_sharedarr(
434
- arr=template_mask, shared_memory_handler=shared_memory_handler
435
- )
436
-
437
- template_tuple = (template_buffer, template.shape, template.dtype)
438
- template_mask_tuple = (
439
- template_mask_buffer,
440
- template_mask.shape,
441
- template_mask.dtype,
442
- )
443
-
444
- target_ft_tuple = (ft_target, fast_ft_shape, backend._complex_dtype)
445
- target_ft2_tuple = (ft_target2, fast_ft_shape, backend._complex_dtype)
446
-
447
- ret = {
448
- "template": template_tuple,
449
- "template_mask": template_mask_tuple,
450
- "ft_target": target_ft_tuple,
451
- "ft_target2": target_ft2_tuple,
452
- "targetshape": target.shape,
453
- "templateshape": template.shape,
454
- "fast_shape": fast_shape,
455
- "fast_ft_shape": fast_ft_shape,
456
- "callback_class": callback_class,
457
- "callback_class_args": callback_class_args,
458
- }
459
-
460
- return ret
461
-
462
-
463
- def flcSphericalMask_setup(
464
- rfftn: Callable,
465
- irfftn: Callable,
466
- template: NDArray,
467
- template_mask: NDArray,
468
- target: NDArray,
469
- fast_shape: Tuple[int],
470
- fast_ft_shape: Tuple[int],
471
- shared_memory_handler: Callable,
472
- callback_class: Callable,
473
- callback_class_args: Dict,
474
- **kwargs,
475
- ) -> Dict:
476
- """
477
- Like :py:meth:`flc_setup` but for rotation invariant masks. In such cases
478
- the score can be computed quicker by not computing the fourier transforms
479
- of the mask for each rotation.
480
-
481
- .. math::
482
-
483
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
484
- {N_m * \\sqrt{
485
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
486
- }
487
-
488
- Where:
489
-
490
- .. math::
491
-
492
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
493
-
494
- and Nm is the number of voxels within the template mask m.
495
-
496
- References
497
- ----------
498
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
499
- Microsc. Microanal. 26, 2516 (2020)
500
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
501
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
502
-
503
- See Also
504
- --------
505
- :py:meth:`corr_scoring`
506
- """
507
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
508
- target_pad = backend.topleft_pad(target, fast_shape)
509
-
510
- # Target and squared target window sums
511
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
512
- ft_template_mask = backend.preallocate_array(fast_ft_shape, complex_dtype)
513
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
514
-
515
- temp = backend.preallocate_array(fast_shape, real_dtype)
516
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
517
- numerator = backend.preallocate_array(1, real_dtype)
518
-
519
- n_observations, norm_func = backend.sum(template_mask), normalize_under_mask
520
- if backend.datatype_bytes(template_mask.dtype) == 2:
521
- norm_func = _normalize_under_mask_overflow_safe
522
- n_observations = backend.sum(
523
- backend.astype(template_mask, backend._overflow_safe_dtype)
524
- )
525
-
526
- template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
527
- rfftn(template_mask_pad, ft_template_mask)
528
-
529
- # Denominator E(X^2) - E(X)^2
530
- rfftn(backend.square(target_pad), ft_target)
531
- backend.multiply(ft_target, ft_template_mask, out=ft_temp)
532
- irfftn(ft_temp, temp2)
533
- backend.divide(temp2, n_observations, out=temp2)
534
-
535
- rfftn(target_pad, ft_target)
536
- backend.multiply(ft_target, ft_template_mask, out=ft_temp)
537
- irfftn(ft_temp, temp)
538
- backend.divide(temp, n_observations, out=temp)
539
- backend.square(temp, out=temp)
540
-
541
- backend.subtract(temp2, temp, out=temp)
542
-
543
- backend.maximum(temp, 0.0, out=temp)
544
- backend.sqrt(temp, out=temp)
545
- backend.multiply(temp, n_observations, out=temp)
546
-
547
- backend.fill(temp2, 0)
548
- nonzero_indices = temp > backend.eps(real_dtype)
549
- temp2[nonzero_indices] = 1 / temp[nonzero_indices]
550
-
551
- norm_func(template=template, mask=template_mask, mask_intensity=n_observations)
552
-
553
- template_buffer = backend.arr_to_sharedarr(
554
- arr=template, shared_memory_handler=shared_memory_handler
555
- )
556
- template_mask_buffer = backend.arr_to_sharedarr(
557
- arr=template_mask, shared_memory_handler=shared_memory_handler
558
- )
559
- target_ft_buffer = backend.arr_to_sharedarr(
560
- arr=ft_target, shared_memory_handler=shared_memory_handler
561
- )
562
- inv_denominator_buffer = backend.arr_to_sharedarr(
563
- arr=temp2, shared_memory_handler=shared_memory_handler
564
- )
565
- numerator_buffer = backend.arr_to_sharedarr(
566
- arr=numerator, shared_memory_handler=shared_memory_handler
567
- )
568
-
569
- template_tuple = (template_buffer, template.shape, template.dtype)
570
- template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
571
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
572
-
573
- inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
574
- numerator_tuple = (numerator_buffer, (1,), real_dtype)
575
-
576
- ret = {
577
- "template": template_tuple,
578
- "template_mask": template_mask_tuple,
579
- "ft_target": target_ft_tuple,
580
- "inv_denominator": inv_denominator_tuple,
581
- "numerator": numerator_tuple,
582
- "targetshape": target.shape,
583
- "templateshape": template.shape,
584
- "fast_shape": fast_shape,
585
- "fast_ft_shape": fast_ft_shape,
586
- "callback_class": callback_class,
587
- "callback_class_args": callback_class_args,
588
- }
589
-
590
- return ret
591
-
592
-
593
- def mcc_setup(
594
- rfftn: Callable,
595
- irfftn: Callable,
596
- template: NDArray,
597
- template_mask: NDArray,
598
- target: NDArray,
599
- target_mask: NDArray,
600
- fast_shape: Tuple[int],
601
- fast_ft_shape: Tuple[int],
602
- shared_memory_handler: Callable,
603
- callback_class: Callable,
604
- callback_class_args: Dict,
605
- **kwargs,
606
- ) -> Dict:
607
- """
608
- Setup to compute a normalized cross-correlation score with masks
609
- for both template and target.
610
-
611
- .. math::
612
-
613
- \\frac{
614
- CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
615
- }{
616
- \\sqrt{
617
- (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
618
- (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
619
- }
620
- }
621
-
622
- Where:
623
-
624
- .. math::
625
-
626
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
627
-
628
-
629
- References
630
- ----------
631
- .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
632
-
633
- See Also
634
- --------
635
- :py:meth:`mcc_scoring`
636
- :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
637
- """
638
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
639
- target = backend.multiply(target, target_mask > 0, out=target)
640
-
641
- target_pad = backend.topleft_pad(target, fast_shape)
642
- target_mask_pad = backend.topleft_pad(target_mask, fast_shape)
643
-
644
- target_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
645
- rfftn(target_pad, target_ft)
646
- target_ft_buffer = backend.arr_to_sharedarr(
647
- arr=target_ft, shared_memory_handler=shared_memory_handler
648
- )
649
-
650
- target_ft2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
651
- rfftn(backend.square(target_pad), target_ft2)
652
- target_ft2_buffer = backend.arr_to_sharedarr(
653
- arr=target_ft2, shared_memory_handler=shared_memory_handler
654
- )
655
-
656
- target_mask_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
657
- rfftn(target_mask_pad, target_mask_ft)
658
- target_mask_ft_buffer = backend.arr_to_sharedarr(
659
- arr=target_mask_ft, shared_memory_handler=shared_memory_handler
660
- )
661
-
662
- template_buffer = backend.arr_to_sharedarr(
663
- arr=template, shared_memory_handler=shared_memory_handler
664
- )
665
- template_mask_buffer = backend.arr_to_sharedarr(
666
- arr=template_mask, shared_memory_handler=shared_memory_handler
667
- )
668
-
669
- template_tuple = (template_buffer, template.shape, template.dtype)
670
- template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
671
-
672
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
673
- target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
674
- target_mask_ft_tuple = (target_mask_ft_buffer, fast_ft_shape, complex_dtype)
675
-
676
- ret = {
677
- "template": template_tuple,
678
- "template_mask": template_mask_tuple,
679
- "ft_target": target_ft_tuple,
680
- "ft_target2": target_ft2_tuple,
681
- "ft_target_mask": target_mask_ft_tuple,
682
- "targetshape": target.shape,
683
- "templateshape": template.shape,
684
- "fast_shape": fast_shape,
685
- "fast_ft_shape": fast_ft_shape,
686
- "callback_class": callback_class,
687
- "callback_class_args": callback_class_args,
688
- }
689
-
690
- return ret
691
-
692
-
693
- def corr_scoring(
694
- template: Tuple[type, Tuple[int], type],
695
- ft_target: Tuple[type, Tuple[int], type],
696
- inv_denominator: Tuple[type, Tuple[int], type],
697
- numerator: Tuple[type, Tuple[int], type],
698
- template_filter: Tuple[type, Tuple[int], type],
699
- targetshape: Tuple[int],
700
- templateshape: Tuple[int],
701
- fast_shape: Tuple[int],
702
- fast_ft_shape: Tuple[int],
703
- rotations: NDArray,
704
- callback_class: CallbackClass,
705
- callback_class_args: Dict,
706
- interpolation_order: int,
707
- convolution_mode: str = "full",
708
- **kwargs,
709
- ) -> CallbackClass:
710
- """
711
- Calculates a normalized cross-correlation between a target f and a template g.
712
-
713
- .. math::
714
-
715
- (CC(f,g) - numerator) \\cdot inv\\_denominator
716
-
717
- Parameters
718
- ----------
719
- template : Tuple[type, Tuple[int], type]
720
- Tuple containing a pointer to the template data, its shape, and its datatype.
721
- ft_target : Tuple[type, Tuple[int], type]
722
- Tuple containing a pointer to the fourier tranform of the target,
723
- its shape, and its datatype.
724
- inv_denominator : Tuple[type, Tuple[int], type]
725
- Tuple containing a pointer to the inverse denominator data, its shape, and its
726
- datatype.
727
- numerator : Tuple[type, Tuple[int], type]
728
- Tuple containing a pointer to the numerator data, its shape, and its datatype.
729
- fast_shape : Tuple[int]
730
- The shape for fast Fourier transform.
731
- fast_ft_shape : Tuple[int]
732
- The shape for fast Fourier transform of the target.
733
- rotations : NDArray
734
- Array containing the rotation matrices to be applied on the template.
735
- real_dtype : type
736
- Data type for the real part of the array.
737
- complex_dtype : type
738
- Data type for the complex part of the array.
739
- callback_class : CallbackClass
740
- A callable class or function for processing the results after each
741
- rotation.
742
- callback_class_args : Dict
743
- Dictionary of arguments to be passed to the callback class if it's
744
- instantiable.
745
- interpolation_order : int
746
- The order of interpolation to be used while rotating the template.
747
- **kwargs :
748
- Additional arguments to be passed to the function.
749
-
750
- Returns
751
- -------
752
- CallbackClass
753
- If callback_class was provided an instance of callback_class otherwise None.
754
-
755
- See Also
756
- --------
757
- :py:meth:`cc_setup`
758
- :py:meth:`corr_setup`
759
- :py:meth:`cam_setup`
760
- :py:meth:`flcSphericalMask_setup`
761
- """
762
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
763
-
764
- callback = callback_class
765
- if callback_class is not None and isinstance(callback_class, type):
766
- callback = callback_class(**callback_class_args)
767
-
768
- template_buffer, template_shape, template_dtype = template
769
- template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
770
- ft_target = backend.sharedarr_to_arr(*ft_target)
771
- inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
772
- numerator = backend.sharedarr_to_arr(*numerator)
773
- template_filter = backend.sharedarr_to_arr(*template_filter)
774
-
775
- norm_func = normalize_under_mask
776
- norm_template, template_mask, mask_sum = False, 1, 1
777
- if "template_mask" in kwargs:
778
- norm_template = True
779
- template_mask = backend.sharedarr_to_arr(*kwargs["template_mask"])
780
- mask_sum = backend.sum(template_mask)
781
- if backend.datatype_bytes(template_mask.dtype) == 2:
782
- norm_func = _normalize_under_mask_overflow_safe
783
- mask_sum = backend.sum(
784
- backend.astype(template_mask, backend._overflow_safe_dtype)
785
- )
786
-
787
- norm_template = conditional_execute(norm_func, norm_template)
788
- norm_numerator = (backend.sum(numerator) != 0) & (backend.size(numerator) != 1)
789
- norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
790
-
791
- norm_denominator = (backend.sum(inv_denominator) != 1) & (
792
- backend.size(inv_denominator) != 1
793
- )
794
- norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
795
- callback_func = conditional_execute(callback, callback_class is not None)
796
- template_filter_func = conditional_execute(
797
- apply_filter, backend.size(template_filter) != 1
798
- )
799
-
800
- arr = backend.preallocate_array(fast_shape, real_dtype)
801
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
802
-
803
- rfftn, irfftn = backend.build_fft(
804
- fast_shape=fast_shape,
805
- fast_ft_shape=fast_ft_shape,
806
- real_dtype=real_dtype,
807
- complex_dtype=complex_dtype,
808
- fftargs=kwargs.get("fftargs", {}),
809
- temp_real=arr,
810
- temp_fft=ft_temp,
811
- )
812
-
813
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
814
- for index in range(rotations.shape[0]):
815
- rotation = rotations[index]
816
- backend.fill(arr, 0)
817
- backend.rotate_array(
818
- arr=template,
819
- rotation_matrix=rotation,
820
- out=arr,
821
- use_geometric_center=True,
822
- order=interpolation_order,
823
- )
824
- norm_template(arr[unpadded_slice], template_mask, mask_sum)
825
-
826
- rfftn(arr, ft_temp)
827
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
828
- backend.multiply(ft_target, ft_temp, out=ft_temp)
829
- irfftn(ft_temp, arr)
830
-
831
- norm_func_numerator(arr, numerator, out=arr)
832
- norm_func_denominator(arr, inv_denominator, out=arr)
833
-
834
- callback_func(
835
- arr,
836
- rotation_matrix=rotation,
837
- rotation_index=index,
838
- **callback_class_args,
839
- )
840
-
841
- return callback
842
-
843
-
844
- def flc_scoring(
845
- template: Tuple[type, Tuple[int], type],
846
- template_mask: Tuple[type, Tuple[int], type],
847
- ft_target: Tuple[type, Tuple[int], type],
848
- ft_target2: Tuple[type, Tuple[int], type],
849
- template_filter: Tuple[type, Tuple[int], type],
850
- targetshape: Tuple[int],
851
- templateshape: Tuple[int],
852
- fast_shape: Tuple[int],
853
- fast_ft_shape: Tuple[int],
854
- rotations: NDArray,
855
- callback_class: CallbackClass,
856
- callback_class_args: Dict,
857
- interpolation_order: int,
858
- **kwargs,
859
- ) -> CallbackClass:
860
- """
861
- Computes a normalized cross-correlation score of a target f a template g
862
- and a mask m:
863
-
864
- .. math::
865
-
866
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
867
- {N_m * \\sqrt{
868
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
869
- }
870
-
871
- Where:
872
-
873
- .. math::
874
-
875
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
876
-
877
- and Nm is the number of voxels within the template mask m.
878
-
879
- References
880
- ----------
881
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
882
- Microsc. Microanal. 26, 2516 (2020)
883
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
884
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
885
- """
886
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
887
-
888
- callback = callback_class
889
- if callback_class is not None and isinstance(callback_class, type):
890
- callback = callback_class(**callback_class_args)
891
-
892
- template = backend.sharedarr_to_arr(*template)
893
- template_mask = backend.sharedarr_to_arr(*template_mask)
894
- ft_target = backend.sharedarr_to_arr(*ft_target)
895
- ft_target2 = backend.sharedarr_to_arr(*ft_target2)
896
- template_filter = backend.sharedarr_to_arr(*template_filter)
897
-
898
- arr = backend.preallocate_array(fast_shape, real_dtype)
899
- temp = backend.preallocate_array(fast_shape, real_dtype)
900
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
901
-
902
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
903
- ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
904
-
905
- rfftn, irfftn = backend.build_fft(
906
- fast_shape=fast_shape,
907
- fast_ft_shape=fast_ft_shape,
908
- real_dtype=real_dtype,
909
- complex_dtype=complex_dtype,
910
- fftargs=kwargs.get("fftargs", {}),
911
- temp_real=arr,
912
- temp_fft=ft_temp,
913
- )
914
- eps = backend.eps(real_dtype)
915
- template_filter_func = conditional_execute(
916
- apply_filter, backend.size(template_filter) != 1
917
- )
918
- callback_func = conditional_execute(callback, callback_class is not None)
919
-
920
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
921
- for index in range(rotations.shape[0]):
922
- rotation = rotations[index]
923
- backend.fill(arr, 0)
924
- backend.fill(temp, 0)
925
- backend.rotate_array(
926
- arr=template,
927
- arr_mask=template_mask,
928
- rotation_matrix=rotation,
929
- out=arr,
930
- out_mask=temp,
931
- use_geometric_center=True,
932
- order=interpolation_order,
933
- )
934
- # Given the amount of FFTs, might aswell normalize properly
935
- n_observations = backend.sum(temp)
936
-
937
- normalize_under_mask(
938
- template=arr[unpadded_slice],
939
- mask=temp[unpadded_slice],
940
- mask_intensity=n_observations,
941
- )
942
-
943
- rfftn(temp, ft_temp)
944
-
945
- backend.multiply(ft_target, ft_temp, out=ft_denom)
946
- irfftn(ft_denom, temp)
947
- backend.divide(temp, n_observations, out=temp)
948
- backend.square(temp, out=temp)
949
-
950
- backend.multiply(ft_target2, ft_temp, out=ft_denom)
951
- irfftn(ft_denom, temp2)
952
- backend.divide(temp2, n_observations, out=temp2)
953
-
954
- backend.subtract(temp2, temp, out=temp)
955
- backend.maximum(temp, 0.0, out=temp)
956
- backend.sqrt(temp, out=temp)
957
- backend.multiply(temp, n_observations, out=temp)
958
-
959
- rfftn(arr, ft_temp)
960
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
961
- backend.multiply(ft_target, ft_temp, out=ft_temp)
962
- irfftn(ft_temp, arr)
963
-
964
- tol = eps
965
- # tol = 1e3 * eps * backend.max(backend.abs(temp))
966
- nonzero_indices = temp > tol
967
- backend.fill(temp2, 0)
968
- temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
969
-
970
- callback_func(
971
- temp2,
972
- rotation_matrix=rotation,
973
- rotation_index=index,
974
- **callback_class_args,
975
- )
976
-
977
- return callback
978
-
979
-
980
- def flc_scoring2(
981
- template: Tuple[type, Tuple[int], type],
982
- template_mask: Tuple[type, Tuple[int], type],
983
- ft_target: Tuple[type, Tuple[int], type],
984
- ft_target2: Tuple[type, Tuple[int], type],
985
- template_filter: Tuple[type, Tuple[int], type],
986
- targetshape: Tuple[int],
987
- templateshape: Tuple[int],
988
- fast_shape: Tuple[int],
989
- fast_ft_shape: Tuple[int],
990
- rotations: NDArray,
991
- callback_class: CallbackClass,
992
- callback_class_args: Dict,
993
- interpolation_order: int,
994
- **kwargs,
995
- ) -> CallbackClass:
996
- """
997
- Computes a normalized cross-correlation score of a target f a template g
998
- and a mask m:
999
-
1000
- .. math::
1001
-
1002
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
1003
- {N_m * \\sqrt{
1004
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
1005
- }
1006
-
1007
- Where:
1008
-
1009
- .. math::
1010
-
1011
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1012
-
1013
- and Nm is the number of voxels within the template mask m.
1014
-
1015
- References
1016
- ----------
1017
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
1018
- Microsc. Microanal. 26, 2516 (2020)
1019
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1020
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
1021
- """
1022
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1023
- callback = callback_class
1024
- if callback_class is not None and isinstance(callback_class, type):
1025
- callback = callback_class(**callback_class_args)
1026
-
1027
- # Retrieve objects from shared memory
1028
- template = backend.sharedarr_to_arr(*template)
1029
- template_mask = backend.sharedarr_to_arr(*template_mask)
1030
- ft_target = backend.sharedarr_to_arr(*ft_target)
1031
- ft_target2 = backend.sharedarr_to_arr(*ft_target2)
1032
- template_filter = backend.sharedarr_to_arr(*template_filter)
1033
-
1034
- arr = backend.preallocate_array(fast_shape, real_dtype)
1035
- temp = backend.preallocate_array(fast_shape, real_dtype)
1036
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
1037
-
1038
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
1039
- ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
1040
-
1041
- eps = backend.eps(real_dtype)
1042
- template_filter_func = conditional_execute(
1043
- apply_filter, backend.size(template_filter) != 1
1044
- )
1045
- callback_func = conditional_execute(callback, callback_class is not None)
1046
-
1047
- squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
1048
- squeeze = tuple(
1049
- slice(0, stop) if i not in squeeze_axis else 0
1050
- for i, stop in enumerate(template.shape)
1051
- )
1052
- squeeze_fast = tuple(
1053
- slice(0, stop) if i not in squeeze_axis else 0
1054
- for i, stop in enumerate(fast_shape)
1055
- )
1056
- squeeze_fast_ft = tuple(
1057
- slice(0, stop) if i not in squeeze_axis else 0
1058
- for i, stop in enumerate(fast_ft_shape)
1059
- )
1060
-
1061
- rfftn, irfftn = backend.build_fft(
1062
- fast_shape=temp[squeeze_fast].shape,
1063
- fast_ft_shape=fast_ft_shape,
1064
- real_dtype=real_dtype,
1065
- complex_dtype=complex_dtype,
1066
- fftargs=kwargs.get("fftargs", {}),
1067
- inverse_fast_shape=fast_shape,
1068
- temp_real=arr[squeeze_fast],
1069
- temp_fft=ft_temp,
1070
- )
1071
- for index in range(rotations.shape[0]):
1072
- rotation = rotations[index]
1073
- backend.fill(arr, 0)
1074
- backend.fill(temp, 0)
1075
- backend.rotate_array(
1076
- arr=template[squeeze],
1077
- arr_mask=template_mask[squeeze],
1078
- rotation_matrix=rotation,
1079
- out=arr[squeeze],
1080
- out_mask=temp[squeeze],
1081
- use_geometric_center=True,
1082
- order=interpolation_order,
1083
- )
1084
- # Given the amount of FFTs, might aswell normalize properly
1085
- n_observations = backend.sum(temp)
1086
-
1087
- normalize_under_mask(
1088
- template=arr[squeeze],
1089
- mask=temp[squeeze],
1090
- mask_intensity=n_observations,
1091
- )
1092
- rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
1093
-
1094
- backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1095
- irfftn(ft_denom, temp)
1096
- backend.divide(temp, n_observations, out=temp)
1097
- backend.square(temp, out=temp)
1098
-
1099
- backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
1100
- irfftn(ft_denom, temp2)
1101
- backend.divide(temp2, n_observations, out=temp2)
1102
-
1103
- backend.subtract(temp2, temp, out=temp)
1104
- backend.maximum(temp, 0.0, out=temp)
1105
- backend.sqrt(temp, out=temp)
1106
- backend.multiply(temp, n_observations, out=temp)
1107
-
1108
- rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
1109
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1110
- backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1111
- irfftn(ft_denom, arr)
1112
-
1113
- nonzero_indices = temp > eps
1114
- backend.fill(temp2, 0)
1115
- temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1116
-
1117
- callback_func(
1118
- temp2,
1119
- rotation_matrix=rotation,
1120
- rotation_index=index,
1121
- **callback_class_args,
1122
- )
1123
- return callback
1124
-
1125
-
1126
- def mcc_scoring(
1127
- template: Tuple[type, Tuple[int], type],
1128
- template_mask: Tuple[type, Tuple[int], type],
1129
- ft_target: Tuple[type, Tuple[int], type],
1130
- ft_target2: Tuple[type, Tuple[int], type],
1131
- ft_target_mask: Tuple[type, Tuple[int], type],
1132
- template_filter: Tuple[type, Tuple[int], type],
1133
- targetshape: Tuple[int],
1134
- templateshape: Tuple[int],
1135
- fast_shape: Tuple[int],
1136
- fast_ft_shape: Tuple[int],
1137
- rotations: NDArray,
1138
- callback_class: CallbackClass,
1139
- callback_class_args: type,
1140
- interpolation_order: int,
1141
- overlap_ratio: float = 0.3,
1142
- **kwargs,
1143
- ) -> CallbackClass:
1144
- """
1145
- Computes a cross-correlation score with masks for both template and target.
1146
-
1147
- .. math::
1148
-
1149
- \\frac{
1150
- CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
1151
- }{
1152
- \\sqrt{
1153
- (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
1154
- (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
1155
- }
1156
- }
1157
-
1158
- Where:
1159
-
1160
- .. math::
1161
-
1162
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1163
-
1164
-
1165
- References
1166
- ----------
1167
- .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
1168
- .. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
1169
-
1170
- See Also
1171
- --------
1172
- :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
38
+ Raises
39
+ ------
40
+ Exception
41
+ Re-raises the last exception.
1173
42
  """
1174
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1175
- callback = callback_class
1176
- if callback_class is not None and isinstance(callback_class, type):
1177
- callback = callback_class(**callback_class_args)
1178
-
1179
- # Retrieve objects from shared memory
1180
- template_buffer, template_shape, template_dtype = template
1181
- template = backend.sharedarr_to_arr(*template)
1182
- target_ft = backend.sharedarr_to_arr(*ft_target)
1183
- target_ft2 = backend.sharedarr_to_arr(*ft_target2)
1184
- template_mask = backend.sharedarr_to_arr(*template_mask)
1185
- target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
1186
- template_filter = backend.sharedarr_to_arr(*template_filter)
1187
-
1188
- axes = tuple(range(template.ndim))
1189
- eps = backend.eps(real_dtype)
1190
-
1191
- # Allocate score and process specific arrays
1192
- template_rot = backend.preallocate_array(fast_shape, real_dtype)
1193
- mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
1194
- numerator = backend.preallocate_array(fast_shape, real_dtype)
1195
-
1196
- temp = backend.preallocate_array(fast_shape, real_dtype)
1197
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
1198
- temp3 = backend.preallocate_array(fast_shape, real_dtype)
1199
- temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
1200
-
1201
- rfftn, irfftn = backend.build_fft(
1202
- fast_shape=fast_shape,
1203
- fast_ft_shape=fast_ft_shape,
1204
- real_dtype=real_dtype,
1205
- complex_dtype=complex_dtype,
1206
- fftargs=kwargs.get("fftargs", {}),
1207
- temp_real=numerator,
1208
- temp_fft=temp_ft,
1209
- )
1210
-
1211
- template_filter_func = conditional_execute(
1212
- apply_filter, backend.size(template_filter) != 1
1213
- )
1214
- callback_func = conditional_execute(callback, callback_class is not None)
1215
-
1216
- # Calculate scores across all rotations
1217
- for index in range(rotations.shape[0]):
1218
- rotation = rotations[index]
1219
- backend.fill(template_rot, 0)
1220
- backend.fill(temp, 0)
1221
-
1222
- backend.rotate_array(
1223
- arr=template,
1224
- arr_mask=template_mask,
1225
- rotation_matrix=rotation,
1226
- out=template_rot,
1227
- out_mask=temp,
1228
- use_geometric_center=True,
1229
- order=interpolation_order,
1230
- )
1231
-
1232
- backend.multiply(template_rot, temp > 0, out=template_rot)
1233
-
1234
- # template_rot_ft
1235
- rfftn(template_rot, temp_ft)
1236
- template_filter_func(ft_template=temp_ft, template_filter=template_filter)
1237
- irfftn(target_mask_ft * temp_ft, temp2)
1238
- irfftn(target_ft * temp_ft, numerator)
1239
-
1240
- # temp template_mask_rot | temp_ft template_mask_rot_ft
1241
- # Calculate overlap of masks at every point in the convolution.
1242
- # Locations with high overlap should not be taken into account.
1243
- rfftn(temp, temp_ft)
1244
- irfftn(temp_ft * target_mask_ft, mask_overlap)
1245
- mask_overlap[:] = np.round(mask_overlap)
1246
- mask_overlap[:] = np.maximum(mask_overlap, eps)
1247
- irfftn(temp_ft * target_ft, temp)
1248
-
1249
- backend.subtract(
1250
- numerator,
1251
- backend.divide(backend.multiply(temp, temp2), mask_overlap),
1252
- out=numerator,
1253
- )
1254
-
1255
- # temp_3 = fixed_denom
1256
- backend.multiply(temp_ft, target_ft2, out=temp_ft)
1257
- irfftn(temp_ft, temp3)
1258
- backend.subtract(
1259
- temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
1260
- )
1261
- backend.maximum(temp3, 0.0, out=temp3)
43
+ if last_type is None:
44
+ return None
45
+ traceback.print_tb(last_traceback)
46
+ raise Exception(last_value)
1262
47
 
1263
- # temp = moving_denom
1264
- rfftn(backend.square(template_rot), temp_ft)
1265
- backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
1266
- irfftn(temp_ft, temp)
1267
48
 
1268
- backend.subtract(
1269
- temp, backend.divide(backend.square(temp2), mask_overlap), out=temp
1270
- )
1271
- backend.maximum(temp, 0.0, out=temp)
1272
-
1273
- # temp_2 = denom
1274
- backend.multiply(temp3, temp, out=temp)
1275
- backend.sqrt(temp, temp2)
1276
-
1277
- # Pixels where `denom` is very small will introduce large
1278
- # numbers after division. To get around this problem,
1279
- # we zero-out problematic pixels.
1280
- tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
1281
- nonzero_indices = temp2 > tol
1282
-
1283
- backend.fill(temp, 0)
1284
- temp[nonzero_indices] = numerator[nonzero_indices] / temp2[nonzero_indices]
1285
- backend.clip(temp, a_min=-1, a_max=1, out=temp)
1286
-
1287
- # Apply overlap ratio threshold
1288
- number_px_threshold = overlap_ratio * backend.max(
1289
- mask_overlap, axis=axes, keepdims=True
1290
- )
1291
- temp[mask_overlap < number_px_threshold] = 0.0
49
+ def _wrap_backend(func):
50
+ @wraps(func)
51
+ def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
52
+ from tme.backends import backend as be
1292
53
 
1293
- callback_func(
1294
- temp,
1295
- rotation_matrix=rotation,
1296
- rotation_index=index,
1297
- **callback_class_args,
1298
- )
54
+ be.change_backend(backend_name, **backend_args)
55
+ return func(*args, **kwargs)
1299
56
 
1300
- return callback
57
+ return wrapper
1301
58
 
1302
59
 
1303
60
  def _setup_template_filter_apply_target_filter(
@@ -1306,42 +63,47 @@ def _setup_template_filter_apply_target_filter(
1306
63
  irfftn: Callable,
1307
64
  fast_shape: Tuple[int],
1308
65
  fast_ft_shape: Tuple[int],
66
+ pad_template_filter: bool = True,
1309
67
  ):
1310
68
  filter_template = isinstance(matching_data.template_filter, Compose)
1311
69
  filter_target = isinstance(matching_data.target_filter, Compose)
1312
70
 
1313
- template_filter = backend.full(shape=(1,), fill_value=1, dtype=backend._float_dtype)
71
+ template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
1314
72
 
1315
73
  if not filter_template and not filter_target:
1316
74
  return template_filter
1317
75
 
1318
- target_temp = backend.astype(
1319
- backend.topleft_pad(matching_data.target, fast_shape), backend._float_dtype
1320
- )
1321
- target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
76
+ target_temp = be.topleft_pad(matching_data.target, fast_shape)
77
+ target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
1322
78
 
1323
- filter_shape = backend.multiply(fast_ft_shape, 1 - matching_data._batch_mask)
1324
- filter_shape[filter_shape == 0] = 1
1325
- fast_shape = backend.multiply(fast_shape, 1 - matching_data._batch_mask)
1326
- fast_shape = fast_shape[fast_shape != 0]
79
+ inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
80
+ filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
81
+ filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
1327
82
 
1328
- fast_shape = tuple(int(x) for x in fast_shape)
1329
- filter_shape = tuple(int(x) for x in filter_shape)
83
+ fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
84
+ fast_shape = tuple(int(x) for x in fast_shape if x != 0)
85
+
86
+ target_temp_ft = rfftn(target_temp, target_temp_ft)
87
+ if filter_template:
88
+ # TODO: Pad to fast shapes and adapt _setup_template_filtering accordingly
89
+ template_fast_shape, template_filter_shape = fast_shape, filter_shape
90
+ if not pad_template_filter:
91
+ template_fast_shape = tuple(int(x) for x in matching_data._template.shape)
92
+ template_filter_shape = [
93
+ int(x) for x in matching_data._output_template_shape
94
+ ]
95
+ template_filter_shape[-1] = template_filter_shape[-1] // 2 + 1
1330
96
 
1331
- rfftn(target_temp, target_temp_ft)
1332
- if isinstance(matching_data.template_filter, Compose):
1333
97
  template_filter = matching_data.template_filter(
1334
- shape=fast_shape,
98
+ shape=template_fast_shape,
1335
99
  return_real_fourier=True,
1336
100
  shape_is_real_fourier=False,
1337
101
  data_rfft=target_temp_ft,
1338
102
  batch_dimension=matching_data._target_dims,
1339
- )
1340
- template_filter = template_filter["data"]
1341
- template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1342
- template_filter = backend.reshape(template_filter, filter_shape)
103
+ )["data"]
104
+ template_filter = be.reshape(template_filter, template_filter_shape)
1343
105
 
1344
- if isinstance(matching_data.target_filter, Compose):
106
+ if filter_target:
1345
107
  target_filter = matching_data.target_filter(
1346
108
  shape=fast_shape,
1347
109
  return_real_fourier=True,
@@ -1349,15 +111,12 @@ def _setup_template_filter_apply_target_filter(
1349
111
  data_rfft=target_temp_ft,
1350
112
  weight_type=None,
1351
113
  batch_dimension=matching_data._target_dims,
1352
- )
1353
- target_filter = target_filter["data"]
1354
- target_filter = backend.reshape(target_filter, filter_shape)
1355
- backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
114
+ )["data"]
115
+ target_filter = be.reshape(target_filter, filter_shape)
116
+ target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1356
117
 
1357
- irfftn(target_temp_ft, target_temp)
1358
- matching_data._target = backend.topleft_pad(
1359
- target_temp, matching_data.target.shape
1360
- )
118
+ target_temp = irfftn(target_temp_ft, target_temp)
119
+ matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)
1361
120
 
1362
121
  return template_filter
1363
122
 
@@ -1371,13 +130,14 @@ def device_memory_handler(func: Callable):
1371
130
  last_type, last_value, last_traceback = sys.exc_info()
1372
131
  try:
1373
132
  with SharedMemoryManager() as smh:
1374
- with backend.set_device(kwargs.get("gpu_index", 0)):
133
+ gpu_index = kwargs.pop("gpu_index") if "gpu_index" in kwargs else 0
134
+ with be.set_device(gpu_index):
1375
135
  return_value = func(shared_memory_handler=smh, *args, **kwargs)
1376
136
  except Exception as e:
1377
137
  print(e)
1378
138
  last_type, last_value, last_traceback = sys.exc_info()
1379
139
  finally:
1380
- handle_traceback(last_type, last_value, last_traceback)
140
+ _handle_traceback(last_type, last_value, last_traceback)
1381
141
  return return_value
1382
142
 
1383
143
  return inner_function
@@ -1391,18 +151,20 @@ def scan(
1391
151
  n_jobs: int = 4,
1392
152
  callback_class: CallbackClass = None,
1393
153
  callback_class_args: Dict = {},
1394
- fftargs: Dict = {},
1395
154
  pad_fourier: bool = True,
155
+ pad_template_filter: bool = True,
1396
156
  interpolation_order: int = 3,
1397
157
  jobs_per_callback_class: int = 8,
1398
- **kwargs,
1399
- ) -> Tuple:
158
+ shared_memory_handler=None,
159
+ ) -> Optional[Tuple]:
1400
160
  """
1401
- Perform template matching.
161
+ Run template matching.
162
+
163
+ .. warning:: ``matching_data`` might be altered or destroyed during computation.
1402
164
 
1403
165
  Parameters
1404
166
  ----------
1405
- matching_data : MatchingData
167
+ matching_data : :py:class:`tme.matching_data.MatchingData`
1406
168
  Template matching data.
1407
169
  matching_setup : Callable
1408
170
  Function pointer to setup function.
@@ -1414,21 +176,21 @@ def scan(
1414
176
  Analyzer class pointer to operate on computed scores.
1415
177
  callback_class_args : dict, optional
1416
178
  Arguments passed to the callback_class. Default is an empty dictionary.
1417
- fftargs : dict, optional
1418
- Arguments for the FFT operations. Default is an empty dictionary.
1419
179
  pad_fourier: bool, optional
1420
180
  Whether to pad target and template to the full convolution shape.
181
+ pad_template_filter: bool, optional
182
+ Whether to pad potential template filters to the full convolution shape.
1421
183
  interpolation_order : int, optional
1422
184
  Order of spline interpolation for rotations.
1423
185
  jobs_per_callback_class : int, optional
1424
186
  How many jobs should be processed by a single callback_class instance,
1425
187
  if one is provided.
1426
- **kwargs : various
1427
- Additional keyword arguments.
188
+ shared_memory_handler : type, optional
189
+ Manager for shared memory objects, None by default.
1428
190
 
1429
191
  Returns
1430
192
  -------
1431
- Tuple
193
+ Optional[Tuple]
1432
194
  The merged results from callback_class if provided otherwise None.
1433
195
 
1434
196
  Examples
@@ -1450,159 +212,95 @@ def scan(
1450
212
 
1451
213
  """
1452
214
  matching_data.to_backend()
1453
- shape_diff = backend.subtract(
1454
- matching_data._output_target_shape, matching_data._output_template_shape
1455
- )
1456
- shape_diff = backend.multiply(shape_diff, 1 - matching_data._batch_mask)
1457
-
1458
- if backend.sum(shape_diff < 0) and not pad_fourier:
1459
- warnings.warn(
1460
- "Target is larger than template and Fourier padding is turned off."
1461
- " This can lead to shifted results. You can swap template and target,"
1462
- " zero-pad the target or turn off template centering."
1463
- )
1464
215
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1465
216
  pad_fourier=pad_fourier
1466
217
  )
1467
218
 
1468
- callback_class_args["fourier_shift"] = fourier_shift
1469
- rfftn, irfftn = backend.build_fft(
219
+ rfftn, irfftn = be.build_fft(
1470
220
  fast_shape=fast_shape,
1471
221
  fast_ft_shape=fast_ft_shape,
1472
- real_dtype=backend._float_dtype,
1473
- complex_dtype=backend._complex_dtype,
1474
- fftargs=fftargs,
222
+ real_dtype=be._float_dtype,
223
+ complex_dtype=be._complex_dtype,
1475
224
  )
1476
-
1477
225
  template_filter = _setup_template_filter_apply_target_filter(
1478
226
  matching_data=matching_data,
1479
227
  rfftn=rfftn,
1480
228
  irfftn=irfftn,
1481
229
  fast_shape=fast_shape,
1482
230
  fast_ft_shape=fast_ft_shape,
231
+ pad_template_filter=pad_template_filter,
1483
232
  )
233
+ template_filter = be.astype(be.to_backend_array(template_filter), be._float_dtype)
1484
234
 
1485
235
  setup = matching_setup(
1486
236
  rfftn=rfftn,
1487
237
  irfftn=irfftn,
1488
238
  template=matching_data.template,
239
+ template_filter=template_filter,
1489
240
  template_mask=matching_data.template_mask,
1490
241
  target=matching_data.target,
1491
242
  target_mask=matching_data.target_mask,
1492
243
  fast_shape=fast_shape,
1493
244
  fast_ft_shape=fast_ft_shape,
1494
- callback_class=callback_class,
1495
- callback_class_args=callback_class_args,
1496
- **kwargs,
245
+ shared_memory_handler=shared_memory_handler,
1497
246
  )
1498
247
  rfftn, irfftn = None, None
1499
-
1500
- template_filter = backend.to_backend_array(template_filter)
1501
- template_filter = backend.astype(template_filter, backend._float_dtype)
1502
- template_filter_buffer = backend.arr_to_sharedarr(
1503
- arr=template_filter,
1504
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
1505
- )
1506
- setup["template_filter"] = (
1507
- template_filter_buffer,
1508
- template_filter.shape,
1509
- template_filter.dtype,
1510
- )
1511
-
1512
- callback_class_args["translation_offset"] = backend.astype(
1513
- matching_data._translation_offset, int
1514
- )
1515
- callback_class_args["thread_safe"] = n_jobs > 1
1516
- callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
1517
-
1518
- n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
1519
- callback_class = setup.pop("callback_class", callback_class)
1520
- callback_class_args = setup.pop("callback_class_args", callback_class_args)
1521
- callback_classes = [callback_class for _ in range(n_callback_classes)]
1522
-
1523
- convolution_mode = "same"
1524
- if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1525
- convolution_mode = "valid"
1526
-
1527
- callback_class_args["fourier_shift"] = fourier_shift
1528
- callback_class_args["convolution_mode"] = convolution_mode
1529
- callback_class_args["targetshape"] = setup["targetshape"]
1530
- callback_class_args["templateshape"] = setup["templateshape"]
1531
-
1532
- if callback_class == MaxScoreOverRotations:
1533
- callback_classes = [
1534
- class_name(
1535
- score_space_shape=fast_shape,
1536
- score_space_dtype=backend._float_dtype,
1537
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
1538
- rotation_space_dtype=backend._int_dtype,
1539
- **callback_class_args,
1540
- )
1541
- for class_name in callback_classes
1542
- ]
1543
-
1544
- matching_data._target, matching_data._template = None, None
1545
- matching_data._target_mask, matching_data._template_mask = None, None
1546
-
1547
- setup["fftargs"] = fftargs.copy()
1548
- convolution_mode = "same"
1549
- if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1550
- convolution_mode = "valid"
1551
- setup["convolution_mode"] = convolution_mode
1552
248
  setup["interpolation_order"] = interpolation_order
1553
- rotation_list = matching_data._split_rotations_on_jobs(n_jobs)
1554
-
1555
- backend.free_cache()
1556
-
1557
- def _run_scoring(backend_name, backend_args, rotations, **kwargs):
1558
- from tme.backends import backend
249
+ setup["template_filter"] = be.to_sharedarr(template_filter, shared_memory_handler)
250
+
251
+ offset = be.to_backend_array(matching_data._translation_offset)
252
+ convmode = "valid" if getattr(matching_data, "_is_padded", False) else "same"
253
+ default_callback_args = {
254
+ "offset": be.astype(offset, be._int_dtype),
255
+ "thread_safe": n_jobs > 1,
256
+ "fourier_shift": fourier_shift,
257
+ "convolution_mode": convmode,
258
+ "targetshape": matching_data.target.shape,
259
+ "templateshape": matching_data.template.shape,
260
+ "fast_shape": fast_shape,
261
+ "indices": getattr(matching_data, "indices", None),
262
+ "shared_memory_handler": shared_memory_handler,
263
+ "only_unique_rotations": True,
264
+ }
265
+ default_callback_args.update(callback_class_args)
1559
266
 
1560
- backend.change_backend(backend_name, **backend_args)
1561
- return matching_score(rotations=rotations, **kwargs)
267
+ matching_data._free_data()
268
+ be.free_cache()
1562
269
 
270
+ # For highly parallel jobs, blocking in certain analyzers becomes a bottleneck
271
+ if getattr(callback_class, "shared", True):
272
+ jobs_per_callback_class = 1
273
+ n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
274
+ callback_classes = [
275
+ callback_class(
276
+ shape=fast_shape,
277
+ **default_callback_args,
278
+ )
279
+ if callback_class is not None
280
+ else None
281
+ for _ in range(n_callback_classes)
282
+ ]
1563
283
  callbacks = Parallel(n_jobs=n_jobs)(
1564
- delayed(_run_scoring)(
1565
- backend_name=backend._backend_name,
1566
- backend_args=backend._backend_args,
284
+ delayed(_wrap_backend(matching_score))(
285
+ backend_name=be._backend_name,
286
+ backend_args=be._backend_args,
1567
287
  rotations=rotation,
1568
- callback_class=callback_classes[index % n_callback_classes],
1569
- callback_class_args=callback_class_args,
288
+ callback=callback_classes[index % n_callback_classes],
1570
289
  **setup,
1571
290
  )
1572
- for index, rotation in enumerate(rotation_list)
291
+ for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
1573
292
  )
1574
293
 
1575
- callbacks = callbacks[0:n_callback_classes]
1576
294
  callbacks = [
1577
- tuple(
1578
- callback._postprocess(
1579
- fourier_shift=fourier_shift,
1580
- convolution_mode=convolution_mode,
1581
- targetshape=setup["targetshape"],
1582
- templateshape=setup["templateshape"],
1583
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
1584
- )
1585
- )
1586
- if hasattr(callback, "_postprocess")
1587
- else tuple(callback)
295
+ tuple(callback._postprocess(**default_callback_args))
1588
296
  for callback in callbacks
1589
297
  if callback is not None
1590
298
  ]
1591
- backend.free_cache()
299
+ be.free_cache()
1592
300
 
1593
- merged_callback = None
1594
301
  if callback_class is not None:
1595
- score_indices = None
1596
- if hasattr(matching_data, "indices"):
1597
- score_indices = matching_data.indices
1598
- merged_callback = callback_class.merge(
1599
- callbacks,
1600
- **callback_class_args,
1601
- score_indices=score_indices,
1602
- inner_merge=True,
1603
- )
1604
-
1605
- return merged_callback
302
+ return callback_class.merge(callbacks, **default_callback_args)
303
+ return None
1606
304
 
1607
305
 
1608
306
  def scan_subsets(
@@ -1616,10 +314,12 @@ def scan_subsets(
1616
314
  template_splits: Dict = {},
1617
315
  pad_target_edges: bool = False,
1618
316
  pad_fourier: bool = True,
317
+ pad_template_filter: bool = True,
1619
318
  interpolation_order: int = 3,
1620
319
  jobs_per_callback_class: int = 8,
1621
- **kwargs,
1622
- ) -> Tuple:
320
+ backend_name: str = None,
321
+ backend_args: Dict = {},
322
+ ) -> Optional[Tuple]:
1623
323
  """
1624
324
  Wrapper around :py:meth:`scan` that supports matching on splits
1625
325
  of ``matching_data``.
@@ -1651,21 +351,17 @@ def scan_subsets(
1651
351
  along each axis.
1652
352
  pad_fourier: bool, optional
1653
353
  Whether to pad target and template to the full convolution shape.
354
+ pad_template_filter: bool, optional
355
+ Whether to pad potential template filters to the full convolution shape.
1654
356
  interpolation_order : int, optional
1655
357
  Order of spline interpolation for rotations.
1656
358
  jobs_per_callback_class : int, optional
1657
359
  How many jobs should be processed by a single callback_class instance,
1658
360
  if ones is provided.
1659
- **kwargs : various
1660
- Additional arguments.
1661
-
1662
- Notes
1663
- -----
1664
- Objects in matching_data might be destroyed during computation.
1665
361
 
1666
362
  Returns
1667
363
  -------
1668
- Tuple
364
+ Optional[Tuple]
1669
365
  The merged results from callback_class if provided otherwise None.
1670
366
 
1671
367
  Examples
@@ -1720,73 +416,59 @@ def scan_subsets(
1720
416
  >>> target_splits = target_splits,
1721
417
  >>> )
1722
418
 
1723
- The retuned ``results`` tuple contains the output of the chosen analyzer.
419
+ The ``results`` tuple contains the output of the chosen analyzer.
1724
420
 
1725
421
  See Also
1726
422
  --------
1727
423
  :py:meth:`tme.matching_utils.compute_parallelization_schedule`
1728
424
  """
1729
- target_splits = split_numpy_array_slices(
1730
- matching_data._target.shape, splits=target_splits
1731
- )
425
+ template_splits = split_shape(matching_data._template.shape, splits=template_splits)
426
+ target_splits = split_shape(matching_data._target.shape, splits=target_splits)
1732
427
  if (len(target_splits) > 1) and not pad_target_edges:
1733
428
  warnings.warn(
1734
429
  "Target splitting without padding target edges leads to unreliable "
1735
430
  "similarity estimates around the split border."
1736
431
  )
1737
-
1738
- template_splits = split_numpy_array_slices(
1739
- matching_data._template.shape, splits=template_splits
1740
- )
1741
- target_pad = matching_data.target_padding(pad_target=pad_target_edges)
432
+ splits = tuple(product(target_splits, template_splits))
1742
433
 
1743
434
  outer_jobs, inner_jobs = job_schedule
1744
- results = Parallel(n_jobs=outer_jobs)(
1745
- delayed(_run_inner)(
1746
- backend_name=backend._backend_name,
1747
- backend_args=backend._backend_args,
1748
- matching_data=matching_data.subset_by_slice(
1749
- target_slice=target_split,
1750
- target_pad=target_pad,
1751
- template_slice=template_split,
1752
- ),
1753
- matching_score=matching_score,
1754
- matching_setup=matching_setup,
1755
- n_jobs=inner_jobs,
435
+ target_pad = matching_data.target_padding(pad_target=pad_target_edges)
436
+ if hasattr(be, "scan"):
437
+ corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
438
+ results = be.scan(
439
+ matching_data=matching_data,
440
+ splits=splits,
441
+ n_jobs=outer_jobs,
442
+ rotate_mask=matching_score != corr_scoring,
1756
443
  callback_class=callback_class,
1757
- callback_class_args=callback_class_args,
1758
- interpolation_order=interpolation_order,
1759
- pad_fourier=pad_fourier,
1760
- gpu_index=index % outer_jobs,
1761
- **kwargs,
1762
444
  )
1763
- for index, (target_split, template_split) in enumerate(
1764
- product(target_splits, template_splits)
445
+ else:
446
+ results = Parallel(n_jobs=outer_jobs)(
447
+ delayed(_wrap_backend(scan))(
448
+ backend_name=be._backend_name,
449
+ backend_args=be._backend_args,
450
+ matching_data=matching_data.subset_by_slice(
451
+ target_slice=target_split,
452
+ target_pad=target_pad,
453
+ template_slice=template_split,
454
+ ),
455
+ matching_score=matching_score,
456
+ matching_setup=matching_setup,
457
+ n_jobs=inner_jobs,
458
+ callback_class=callback_class,
459
+ callback_class_args=callback_class_args,
460
+ interpolation_order=interpolation_order,
461
+ pad_fourier=pad_fourier,
462
+ gpu_index=index % outer_jobs,
463
+ pad_template_filter=pad_template_filter,
464
+ )
465
+ for index, (target_split, template_split) in enumerate(splits)
1765
466
  )
1766
- )
1767
-
1768
- matching_data._target, matching_data._template = None, None
1769
- matching_data._target_mask, matching_data._template_mask = None, None
1770
467
 
1771
- candidates = None
468
+ matching_data._free_data()
1772
469
  if callback_class is not None:
1773
- candidates = callback_class.merge(
1774
- results, **callback_class_args, inner_merge=False
1775
- )
1776
-
1777
- return candidates
1778
-
1779
-
1780
- MATCHING_EXHAUSTIVE_REGISTER = {
1781
- "CC": (cc_setup, corr_scoring),
1782
- "LCC": (lcc_setup, corr_scoring),
1783
- "CORR": (corr_setup, corr_scoring),
1784
- "CAM": (cam_setup, corr_scoring),
1785
- "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1786
- "FLC": (flc_setup, flc_scoring),
1787
- "FLC2": (flc_setup, flc_scoring2),
1788
- "MCC": (mcc_setup, mcc_scoring),
1789
- }
470
+ return callback_class.merge(results, **callback_class_args)
471
+ return None
1790
472
 
1791
473
 
1792
474
  def register_matching_exhaustive(
@@ -1803,20 +485,17 @@ def register_matching_exhaustive(
1803
485
  matching : str
1804
486
  Name of the matching method.
1805
487
  matching_setup : Callable
1806
- The setup function associated with the name.
488
+ Corresponding setup function.
1807
489
  matching_scoring : Callable
1808
- The scoring function associated with the name.
490
+ Corresponing scoring function.
1809
491
  memory_class : MatchingMemoryUsage
1810
- The custom memory estimation class extending
1811
- :py:class:`tme.matching_memory.MatchingMemoryUsage`.
492
+ Child of :py:class:`tme.memory.MatchingMemoryUsage`.
1812
493
 
1813
494
  Raises
1814
495
  ------
1815
496
  ValueError
1816
- If a function with the name ``matching`` already exists in the registry.
1817
- ValueError
1818
- If ``memory_class`` is not a subclass of
1819
- :py:class:`tme.matching_memory.MatchingMemoryUsage`.
497
+ If a function with the name ``matching`` already exists in the registry, or
498
+ if ``memory_class`` is no child of :py:class:`tme.memory.MatchingMemoryUsage`.
1820
499
  """
1821
500
 
1822
501
  if matching in MATCHING_EXHAUSTIVE_REGISTER: