pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -4,1300 +4,57 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- import os
8
7
  import sys
9
8
  import warnings
10
- from itertools import product
11
- from typing import Callable, Tuple, Dict
9
+ import traceback
12
10
  from functools import wraps
11
+ from itertools import product
12
+ from typing import Callable, Tuple, Dict, Optional
13
+
13
14
  from joblib import Parallel, delayed
14
15
  from multiprocessing.managers import SharedMemoryManager
15
16
 
16
- import numpy as np
17
- from scipy.ndimage import laplace
18
-
19
- from .analyzer import MaxScoreOverRotations
20
- from .matching_utils import (
21
- handle_traceback,
22
- split_numpy_array_slices,
23
- conditional_execute,
24
- )
25
- from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
17
+ from .backends import backend as be
26
18
  from .preprocessing import Compose
27
- from .matching_data import MatchingData
28
- from .backends import backend
29
- from .types import NDArray, CallbackClass
30
-
31
- os.environ["MKL_NUM_THREADS"] = "1"
32
- os.environ["OMP_NUM_THREADS"] = "1"
33
- os.environ["PYFFTW_NUM_THREADS"] = "1"
34
- os.environ["OPENBLAS_NUM_THREADS"] = "1"
35
-
36
-
37
- def _run_inner(backend_name, backend_args, **kwargs):
38
- from tme.backends import backend
39
-
40
- backend.change_backend(backend_name, **backend_args)
41
- return scan(**kwargs)
42
-
43
-
44
- def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
45
- """
46
 - Standardizes the values in template by subtracting the mean and dividing by the
47
- standard deviation based on the elements in mask. Subsequently, the template is
48
- multiplied by the mask.
49
-
50
- Parameters
51
- ----------
52
- template : NDArray
53
- The data array to be normalized. This array is modified in-place.
54
- mask : NDArray
55
- A boolean array of the same shape as `template`. True values indicate the positions in `template`
56
- to consider for normalization.
57
- mask_intensity : float
58
- Mask intensity used to compute expectations.
59
-
60
- References
61
- ----------
62
- .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
63
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
64
-
65
- Returns
66
- -------
67
- None
68
- This function modifies `template` in-place and does not return any value.
69
- """
70
- masked_mean = backend.sum(backend.multiply(template, mask))
71
- masked_mean = backend.divide(masked_mean, mask_intensity)
72
- masked_std = backend.sum(backend.multiply(backend.square(template), mask))
73
- masked_std = backend.subtract(
74
- masked_std / mask_intensity, backend.square(masked_mean)
75
- )
76
- masked_std = backend.sqrt(backend.maximum(masked_std, 0))
77
-
78
- backend.subtract(template, masked_mean, out=template)
79
- backend.divide(template, masked_std, out=template)
80
- backend.multiply(template, mask, out=template)
81
- return None
82
-
83
-
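As context for the formula in the removed docstring above: a minimal NumPy sketch of the same standardization (masked mean and standard deviation, then re-applying the mask). The function and array names below are illustrative only, not the package API.

import numpy as np

def normalize_under_mask(template, mask, mask_intensity):
    # Masked mean and standard deviation, as described in the removed docstring
    masked_mean = np.sum(template * mask) / mask_intensity
    masked_var = np.sum(np.square(template) * mask) / mask_intensity - masked_mean**2
    masked_std = np.sqrt(max(masked_var, 0))
    template -= masked_mean
    template /= masked_std
    template *= mask
    return None

rng = np.random.default_rng(0)
template = rng.normal(size=(8, 8))
mask = (rng.random((8, 8)) > 0.5).astype(float)
normalize_under_mask(template, mask, mask.sum())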
84
- def _normalize_under_mask_overflow_safe(
85
- template: NDArray, mask: NDArray, mask_intensity
86
- ) -> None:
87
- _template = backend.astype(template, backend._overflow_safe_dtype)
88
- _mask = backend.astype(mask, backend._overflow_safe_dtype)
89
- normalize_under_mask(template=_template, mask=_mask, mask_intensity=mask_intensity)
90
- template[:] = backend.astype(_template, template.dtype)
91
- return None
92
-
93
-
94
- def apply_filter(ft_template, template_filter):
95
- # This is an approximation to applying the mask, irfftn, normalize, rfftn
96
- std_before = backend.std(ft_template)
97
- backend.multiply(ft_template, template_filter, out=ft_template)
98
- backend.multiply(
99
- ft_template, std_before / backend.std(ft_template), out=ft_template
100
- )
101
-
102
-
103
- def cc_setup(
104
- rfftn: Callable,
105
- irfftn: Callable,
106
- template: NDArray,
107
- target: NDArray,
108
- fast_shape: Tuple[int],
109
- fast_ft_shape: Tuple[int],
110
- shared_memory_handler: Callable,
111
- callback_class: Callable,
112
- callback_class_args: Dict,
113
- **kwargs,
114
- ) -> Dict:
115
- """
116
- Setup to compute the cross-correlation between a target f and template g
117
- defined as:
118
-
119
- .. math::
120
-
121
- \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
122
-
123
-
124
- See Also
125
- --------
126
- :py:meth:`corr_scoring`
127
- :py:class:`tme.matching_optimization.CrossCorrelation`
128
- """
129
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
130
- target_pad = backend.topleft_pad(target, fast_shape)
131
- target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
132
-
133
- rfftn(target_pad, target_pad_ft)
134
-
135
- target_ft_out = backend.arr_to_sharedarr(
136
- arr=target_pad_ft, shared_memory_handler=shared_memory_handler
137
- )
138
-
139
- template_out = backend.arr_to_sharedarr(
140
- arr=template, shared_memory_handler=shared_memory_handler
141
- )
142
- inv_denominator_buffer = backend.arr_to_sharedarr(
143
- arr=backend.preallocate_array(1, real_dtype) + 1,
144
- shared_memory_handler=shared_memory_handler,
145
- )
146
- numerator_buffer = backend.arr_to_sharedarr(
147
- arr=backend.preallocate_array(1, real_dtype),
148
- shared_memory_handler=shared_memory_handler,
149
- )
150
-
151
- target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
152
- template_tuple = (template_out, template.shape, template.dtype)
153
-
154
- inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
155
- numerator_tuple = (numerator_buffer, (1,), real_dtype)
156
-
157
- ret = {
158
- "template": template_tuple,
159
- "ft_target": target_ft_tuple,
160
- "inv_denominator": inv_denominator_tuple,
161
- "numerator": numerator_tuple,
162
- "targetshape": target.shape,
163
- "templateshape": template.shape,
164
- "fast_shape": fast_shape,
165
- "fast_ft_shape": fast_ft_shape,
166
- "callback_class": callback_class,
167
- "callback_class_args": callback_class_args,
168
- "use_memmap": kwargs.get("use_memmap", False),
169
- "temp_dir": kwargs.get("temp_dir", None),
170
- }
171
-
172
- return ret
173
-
174
-
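The cross-correlation that cc_setup prepares buffers for reduces to a single Fourier-space product. A small NumPy illustration of the docstring formula F^-1(F(f) * conj(F(g))); the padding and names here are assumptions for the sketch, not the package API.

import numpy as np

rng = np.random.default_rng(0)
f = rng.normal(size=(32, 32))          # target
g = np.zeros_like(f)
g[:8, :8] = rng.normal(size=(8, 8))    # template, zero-padded to the target shape

# Cross-correlation via the Fourier transform
cc = np.fft.irfftn(np.fft.rfftn(f) * np.conj(np.fft.rfftn(g)), s=f.shape)
print(cc.shape)  # same shape as the target; the peak marks the best overlap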
175
- def lcc_setup(**kwargs) -> Dict:
176
- """
177
- Setup to compute the cross-correlation between a laplace transformed target f
178
- and laplace transformed template g defined as:
179
-
180
- .. math::
181
-
182
- \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)
183
-
184
-
185
- See Also
186
- --------
187
- :py:meth:`corr_scoring`
188
- :py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
189
- """
190
- kwargs["target"] = laplace(kwargs["target"], mode="wrap")
191
- kwargs["template"] = laplace(kwargs["template"], mode="wrap")
192
- return cc_setup(**kwargs)
193
-
194
-
195
- def corr_setup(
196
- rfftn: Callable,
197
- irfftn: Callable,
198
- template: NDArray,
199
- template_mask: NDArray,
200
- target: NDArray,
201
- fast_shape: Tuple[int],
202
- fast_ft_shape: Tuple[int],
203
- shared_memory_handler: Callable,
204
- callback_class: Callable,
205
- callback_class_args: Dict,
206
- **kwargs,
207
- ) -> Dict:
208
- """
209
- Setup to compute a normalized cross-correlation between a target f and a template g.
210
-
211
- .. math::
212
-
213
- \\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
214
- {(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}}
215
-
216
- Where:
217
-
218
- .. math::
219
-
220
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
221
-
222
- and m is a mask with the same dimension as the template filled with ones.
223
-
224
- References
225
- ----------
226
- .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
227
- and Magic.
228
-
229
- See Also
230
- --------
231
- :py:meth:`corr_scoring`
232
- :py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
233
- """
234
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
235
- target_pad = backend.topleft_pad(target, fast_shape)
236
-
237
- # The exact composition of the denominator is debatable
238
- # scikit-image match_template multiplies the running sum of the target
239
- # with a scaling factor derived from the template. This is probably appropriate
240
- # in pattern matching situations where the template exists in the target
241
- window_template = backend.topleft_pad(template_mask, fast_shape)
242
- ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
243
- rfftn(window_template, ft_window_template)
244
- window_template = None
245
-
246
- # Target and squared target window sums
247
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
248
- ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
249
- denominator = backend.preallocate_array(fast_shape, real_dtype)
250
- target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
251
- rfftn(target_pad, ft_target)
252
-
253
- rfftn(backend.square(target_pad), ft_target2)
254
- backend.multiply(ft_target2, ft_window_template, out=ft_target2)
255
- irfftn(ft_target2, denominator)
256
-
257
- backend.multiply(ft_target, ft_window_template, out=ft_window_template)
258
- irfftn(ft_window_template, target_window_sum)
259
-
260
- target_pad, ft_target2, ft_window_template = None, None, None
261
-
262
- # Normalizing constants
263
- n_observations = backend.sum(template_mask)
264
- template_mean = backend.sum(backend.multiply(template, template_mask))
265
- template_mean = backend.divide(template_mean, n_observations)
266
- template_ssd = backend.sum(
267
- backend.square(
268
- backend.multiply(backend.subtract(template, template_mean), template_mask)
269
- )
270
- )
271
- template_volume = np.prod(tuple(int(x) for x in template.shape))
272
- backend.multiply(template, template_mask, out=template)
273
-
274
- # Final numerator is score - numerator
275
- numerator = backend.multiply(target_window_sum, template_mean)
276
-
277
- # Compute denominator
278
- backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
279
- backend.divide(target_window_sum, template_volume, out=target_window_sum)
280
-
281
- backend.subtract(denominator, target_window_sum, out=denominator)
282
- backend.multiply(denominator, template_ssd, out=denominator)
283
- backend.maximum(denominator, 0, out=denominator)
284
- backend.sqrt(denominator, out=denominator)
285
- target_window_sum = None
286
-
287
- # Invert denominator to compute final score as product
288
- denominator_mask = denominator > backend.eps(denominator.dtype)
289
- inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
290
- inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]
291
-
292
- # Convert arrays used in subsequent fitting to SharedMemory objects
293
- template_buffer = backend.arr_to_sharedarr(
294
- arr=template, shared_memory_handler=shared_memory_handler
295
- )
296
- target_ft_buffer = backend.arr_to_sharedarr(
297
- arr=ft_target, shared_memory_handler=shared_memory_handler
298
- )
299
- inv_denominator_buffer = backend.arr_to_sharedarr(
300
- arr=inv_denominator, shared_memory_handler=shared_memory_handler
301
- )
302
- numerator_buffer = backend.arr_to_sharedarr(
303
- arr=numerator, shared_memory_handler=shared_memory_handler
304
- )
305
-
306
- template_tuple = (template_buffer, template.shape, template.dtype)
307
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
308
-
309
- inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
310
- numerator_tuple = (numerator_buffer, fast_shape, real_dtype)
311
-
312
- ft_target, inv_denominator, numerator = None, None, None
313
-
314
- ret = {
315
- "template": template_tuple,
316
- "ft_target": target_ft_tuple,
317
- "inv_denominator": inv_denominator_tuple,
318
- "numerator": numerator_tuple,
319
- "targetshape": target.shape,
320
- "templateshape": template.shape,
321
- "fast_shape": fast_shape,
322
- "fast_ft_shape": fast_ft_shape,
323
- "callback_class": callback_class,
324
- "callback_class_args": callback_class_args,
325
- "template_mean": kwargs.get("template_mean", template_mean),
326
- }
327
-
328
- return ret
329
-
330
-
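The score assembled from the corr_setup buffers follows Lewis' fast normalized cross-correlation. A brute-force check of the docstring formula at a single offset might look as follows (illustrative names, not the package API).

import numpy as np

rng = np.random.default_rng(0)
target = rng.normal(size=(32, 32))
template = rng.normal(size=(8, 8))

def ncc_at(target, template, offset):
    # Window of the target covered by the template at this offset
    sl = tuple(slice(o, o + s) for o, s in zip(offset, template.shape))
    window = target[sl]
    n = template.size
    g_mean = template.mean()
    numerator = np.sum(window * template) - g_mean * window.sum()
    win_var = np.sum(window**2) - window.sum() ** 2 / n   # CC(f^2, m) - CC(f, m)^2 / N_g
    ssd_g = np.sum((template - g_mean) ** 2)               # sigma_g term
    denominator = np.sqrt(max(win_var * ssd_g, 0))
    return numerator / denominator if denominator > 0 else 0.0

print(ncc_at(target, template, (5, 7)))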
331
- def cam_setup(**kwargs):
332
- """
333
- Setup to compute a normalized cross-correlation between a target f and a template g
334
- over their means. In practice this can be expressed like the cross-correlation
335
- CORR defined in :py:meth:`corr_scoring`, so that:
336
-
337
- Notes
338
- -----
339
-
340
- .. math::
341
-
342
- \\text{CORR}(f-\\overline{f}, g - \\overline{g})
343
-
344
- Where
345
-
346
- .. math::
347
-
348
- \\overline{f}, \\overline{g}
349
-
350
- are the mean of f and g respectively.
351
-
352
- References
353
- ----------
354
- .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
355
- and Magic.
356
-
357
- See Also
358
- --------
359
- :py:meth:`corr_scoring`.
360
- :py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
361
- """
362
- kwargs["template"] = kwargs["template"] - kwargs["template"].mean()
363
- return corr_setup(**kwargs)
364
-
365
-
366
- def flc_setup(
367
- rfftn: Callable,
368
- irfftn: Callable,
369
- template: NDArray,
370
- template_mask: NDArray,
371
- target: NDArray,
372
- fast_shape: Tuple[int],
373
- fast_ft_shape: Tuple[int],
374
- shared_memory_handler: Callable,
375
- callback_class: Callable,
376
- callback_class_args: Dict,
377
- **kwargs,
378
- ) -> Dict:
379
- """
380
 - Setup to compute a normalized cross-correlation score of a target f, a template g,
381
- and a mask m:
382
-
383
- .. math::
384
-
385
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
386
- {N_m * \\sqrt{
387
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
388
- }
389
-
390
- Where:
391
-
392
- .. math::
393
-
394
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
395
-
396
- and Nm is the number of voxels within the template mask m.
397
-
398
- References
399
- ----------
400
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
401
- Microsc. Microanal. 26, 2516 (2020)
402
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
403
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
404
-
405
- See Also
406
- --------
407
- :py:meth:`flc_scoring`
408
- """
409
- target_pad = backend.topleft_pad(target, fast_shape)
410
-
411
- # Target and squared target window sums
412
- ft_target = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
413
- ft_target2 = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
414
- rfftn(target_pad, ft_target)
415
- backend.square(target_pad, out=target_pad)
416
- rfftn(target_pad, ft_target2)
417
-
418
- # Convert arrays used in subsequent fitting to SharedMemory objects
419
- ft_target = backend.arr_to_sharedarr(
420
- arr=ft_target, shared_memory_handler=shared_memory_handler
421
- )
422
- ft_target2 = backend.arr_to_sharedarr(
423
- arr=ft_target2, shared_memory_handler=shared_memory_handler
424
- )
425
-
426
- normalize_under_mask(
427
- template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
428
- )
429
-
430
- template_buffer = backend.arr_to_sharedarr(
431
- arr=template, shared_memory_handler=shared_memory_handler
432
- )
433
- template_mask_buffer = backend.arr_to_sharedarr(
434
- arr=template_mask, shared_memory_handler=shared_memory_handler
435
- )
436
-
437
- template_tuple = (template_buffer, template.shape, template.dtype)
438
- template_mask_tuple = (
439
- template_mask_buffer,
440
- template_mask.shape,
441
- template_mask.dtype,
442
- )
443
-
444
- target_ft_tuple = (ft_target, fast_ft_shape, backend._complex_dtype)
445
- target_ft2_tuple = (ft_target2, fast_ft_shape, backend._complex_dtype)
446
-
447
- ret = {
448
- "template": template_tuple,
449
- "template_mask": template_mask_tuple,
450
- "ft_target": target_ft_tuple,
451
- "ft_target2": target_ft2_tuple,
452
- "targetshape": target.shape,
453
- "templateshape": template.shape,
454
- "fast_shape": fast_shape,
455
- "fast_ft_shape": fast_ft_shape,
456
- "callback_class": callback_class,
457
- "callback_class_args": callback_class_args,
458
- }
459
-
460
- return ret
461
-
462
-
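The FLC score described in the flc_setup docstring can likewise be evaluated directly at one offset: standardize the template under its mask, then divide the masked correlation by N_m times the masked local standard deviation of the target window. A hedged NumPy sketch with illustrative names:

import numpy as np

rng = np.random.default_rng(1)
target = rng.normal(size=(32, 32))
template = rng.normal(size=(8, 8))
mask = np.zeros_like(template)
mask[2:6, 2:6] = 1                      # arbitrary binary template mask
n_obs = mask.sum()

# Standardize the template under the mask (see normalize_under_mask above)
t_mean = (template * mask).sum() / n_obs
t_std = np.sqrt((template**2 * mask).sum() / n_obs - t_mean**2)
template_norm = (template - t_mean) / t_std * mask

offset = (4, 9)
sl = tuple(slice(o, o + s) for o, s in zip(offset, template.shape))
window = target[sl]

local_mean = (window * mask).sum() / n_obs
local_var = (window**2 * mask).sum() / n_obs - local_mean**2
score = (window * template_norm).sum() / (n_obs * np.sqrt(local_var))
print(score)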
463
- def flcSphericalMask_setup(
464
- rfftn: Callable,
465
- irfftn: Callable,
466
- template: NDArray,
467
- template_mask: NDArray,
468
- target: NDArray,
469
- fast_shape: Tuple[int],
470
- fast_ft_shape: Tuple[int],
471
- shared_memory_handler: Callable,
472
- callback_class: Callable,
473
- callback_class_args: Dict,
474
- **kwargs,
475
- ) -> Dict:
476
- """
477
- Like :py:meth:`flc_setup` but for rotation invariant masks. In such cases
478
 - the score can be computed more quickly by not recomputing the Fourier transforms
479
- of the mask for each rotation.
480
-
481
- .. math::
482
-
483
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
484
- {N_m * \\sqrt{
485
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
486
- }
487
-
488
- Where:
489
-
490
- .. math::
491
-
492
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
493
-
494
- and Nm is the number of voxels within the template mask m.
495
-
496
- References
497
- ----------
498
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
499
- Microsc. Microanal. 26, 2516 (2020)
500
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
501
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
502
-
503
- See Also
504
- --------
505
- :py:meth:`corr_scoring`
506
- """
507
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
508
- target_pad = backend.topleft_pad(target, fast_shape)
509
-
510
- # Target and squared target window sums
511
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
512
- ft_template_mask = backend.preallocate_array(fast_ft_shape, complex_dtype)
513
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
514
-
515
- temp = backend.preallocate_array(fast_shape, real_dtype)
516
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
517
- numerator = backend.preallocate_array(1, real_dtype)
518
-
519
- n_observations, norm_func = backend.sum(template_mask), normalize_under_mask
520
- if backend.datatype_bytes(template_mask.dtype) == 2:
521
- norm_func = _normalize_under_mask_overflow_safe
522
- n_observations = backend.sum(
523
- backend.astype(template_mask, backend._overflow_safe_dtype)
524
- )
525
-
526
- template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
527
- rfftn(template_mask_pad, ft_template_mask)
528
-
529
- # Denominator E(X^2) - E(X)^2
530
- rfftn(backend.square(target_pad), ft_target)
531
- backend.multiply(ft_target, ft_template_mask, out=ft_temp)
532
- irfftn(ft_temp, temp2)
533
- backend.divide(temp2, n_observations, out=temp2)
534
-
535
- rfftn(target_pad, ft_target)
536
- backend.multiply(ft_target, ft_template_mask, out=ft_temp)
537
- irfftn(ft_temp, temp)
538
- backend.divide(temp, n_observations, out=temp)
539
- backend.square(temp, out=temp)
540
-
541
- backend.subtract(temp2, temp, out=temp)
542
-
543
- backend.maximum(temp, 0.0, out=temp)
544
- backend.sqrt(temp, out=temp)
545
- backend.multiply(temp, n_observations, out=temp)
546
-
547
- backend.fill(temp2, 0)
548
- nonzero_indices = temp > backend.eps(real_dtype)
549
- temp2[nonzero_indices] = 1 / temp[nonzero_indices]
550
-
551
- norm_func(template=template, mask=template_mask, mask_intensity=n_observations)
552
-
553
- template_buffer = backend.arr_to_sharedarr(
554
- arr=template, shared_memory_handler=shared_memory_handler
555
- )
556
- template_mask_buffer = backend.arr_to_sharedarr(
557
- arr=template_mask, shared_memory_handler=shared_memory_handler
558
- )
559
- target_ft_buffer = backend.arr_to_sharedarr(
560
- arr=ft_target, shared_memory_handler=shared_memory_handler
561
- )
562
- inv_denominator_buffer = backend.arr_to_sharedarr(
563
- arr=temp2, shared_memory_handler=shared_memory_handler
564
- )
565
- numerator_buffer = backend.arr_to_sharedarr(
566
- arr=numerator, shared_memory_handler=shared_memory_handler
567
- )
568
-
569
- template_tuple = (template_buffer, template.shape, template.dtype)
570
- template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
571
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
572
-
573
- inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
574
- numerator_tuple = (numerator_buffer, (1,), real_dtype)
575
-
576
- ret = {
577
- "template": template_tuple,
578
- "template_mask": template_mask_tuple,
579
- "ft_target": target_ft_tuple,
580
- "inv_denominator": inv_denominator_tuple,
581
- "numerator": numerator_tuple,
582
- "targetshape": target.shape,
583
- "templateshape": template.shape,
584
- "fast_shape": fast_shape,
585
- "fast_ft_shape": fast_ft_shape,
586
- "callback_class": callback_class,
587
- "callback_class_args": callback_class_args,
588
- }
589
-
590
- return ret
591
-
592
-
593
- def mcc_setup(
594
- rfftn: Callable,
595
- irfftn: Callable,
596
- template: NDArray,
597
- template_mask: NDArray,
598
- target: NDArray,
599
- target_mask: NDArray,
600
- fast_shape: Tuple[int],
601
- fast_ft_shape: Tuple[int],
602
- shared_memory_handler: Callable,
603
- callback_class: Callable,
604
- callback_class_args: Dict,
605
- **kwargs,
606
- ) -> Dict:
607
- """
608
- Setup to compute a normalized cross-correlation score with masks
609
- for both template and target.
610
-
611
- .. math::
612
-
613
- \\frac{
614
- CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
615
- }{
616
- \\sqrt{
617
- (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
618
- (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
619
- }
620
- }
621
-
622
- Where:
623
-
624
- .. math::
625
-
626
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
19
+ from .matching_utils import split_shape
20
+ from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
21
+ from .types import CallbackClass, MatchingData
22
+ from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
627
23
 
628
24
 
629
- References
630
- ----------
631
- .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
632
-
633
- See Also
634
- --------
635
- :py:meth:`mcc_scoring`
636
- :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
25
+ def _handle_traceback(last_type, last_value, last_traceback):
637
26
  """
638
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
639
- target = backend.multiply(target, target_mask > 0, out=target)
640
-
641
- target_pad = backend.topleft_pad(target, fast_shape)
642
- target_mask_pad = backend.topleft_pad(target_mask, fast_shape)
643
-
644
- target_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
645
- rfftn(target_pad, target_ft)
646
- target_ft_buffer = backend.arr_to_sharedarr(
647
- arr=target_ft, shared_memory_handler=shared_memory_handler
648
- )
649
-
650
- target_ft2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
651
- rfftn(backend.square(target_pad), target_ft2)
652
- target_ft2_buffer = backend.arr_to_sharedarr(
653
- arr=target_ft2, shared_memory_handler=shared_memory_handler
654
- )
655
-
656
- target_mask_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
657
- rfftn(target_mask_pad, target_mask_ft)
658
- target_mask_ft_buffer = backend.arr_to_sharedarr(
659
- arr=target_mask_ft, shared_memory_handler=shared_memory_handler
660
- )
661
-
662
- template_buffer = backend.arr_to_sharedarr(
663
- arr=template, shared_memory_handler=shared_memory_handler
664
- )
665
- template_mask_buffer = backend.arr_to_sharedarr(
666
- arr=template_mask, shared_memory_handler=shared_memory_handler
667
- )
668
-
669
- template_tuple = (template_buffer, template.shape, template.dtype)
670
- template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
671
-
672
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
673
- target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
674
- target_mask_ft_tuple = (target_mask_ft_buffer, fast_ft_shape, complex_dtype)
675
-
676
- ret = {
677
- "template": template_tuple,
678
- "template_mask": template_mask_tuple,
679
- "ft_target": target_ft_tuple,
680
- "ft_target2": target_ft2_tuple,
681
- "ft_target_mask": target_mask_ft_tuple,
682
- "targetshape": target.shape,
683
- "templateshape": template.shape,
684
- "fast_shape": fast_shape,
685
- "fast_ft_shape": fast_ft_shape,
686
- "callback_class": callback_class,
687
- "callback_class_args": callback_class_args,
688
- }
689
-
690
- return ret
691
-
692
-
693
- def corr_scoring(
694
- template: Tuple[type, Tuple[int], type],
695
- ft_target: Tuple[type, Tuple[int], type],
696
- inv_denominator: Tuple[type, Tuple[int], type],
697
- numerator: Tuple[type, Tuple[int], type],
698
- template_filter: Tuple[type, Tuple[int], type],
699
- targetshape: Tuple[int],
700
- templateshape: Tuple[int],
701
- fast_shape: Tuple[int],
702
- fast_ft_shape: Tuple[int],
703
- rotations: NDArray,
704
- callback_class: CallbackClass,
705
- callback_class_args: Dict,
706
- interpolation_order: int,
707
- convolution_mode: str = "full",
708
- **kwargs,
709
- ) -> CallbackClass:
710
- """
711
- Calculates a normalized cross-correlation between a target f and a template g.
712
-
713
- .. math::
714
-
715
- (CC(f,g) - numerator) \\cdot inv\\_denominator
27
+ Handle sys.exc_info().
716
28
 
717
29
  Parameters
718
30
  ----------
719
- template : Tuple[type, Tuple[int], type]
720
- Tuple containing a pointer to the template data, its shape, and its datatype.
721
- ft_target : Tuple[type, Tuple[int], type]
722
 - Tuple containing a pointer to the Fourier transform of the target,
723
- its shape, and its datatype.
724
- inv_denominator : Tuple[type, Tuple[int], type]
725
- Tuple containing a pointer to the inverse denominator data, its shape, and its
726
- datatype.
727
- numerator : Tuple[type, Tuple[int], type]
728
- Tuple containing a pointer to the numerator data, its shape, and its datatype.
729
- fast_shape : Tuple[int]
730
- The shape for fast Fourier transform.
731
- fast_ft_shape : Tuple[int]
732
- The shape for fast Fourier transform of the target.
733
- rotations : NDArray
734
- Array containing the rotation matrices to be applied on the template.
735
- real_dtype : type
736
- Data type for the real part of the array.
737
- complex_dtype : type
738
- Data type for the complex part of the array.
739
- callback_class : CallbackClass
740
- A callable class or function for processing the results after each
741
- rotation.
742
- callback_class_args : Dict
743
- Dictionary of arguments to be passed to the callback class if it's
744
- instantiable.
745
- interpolation_order : int
746
- The order of interpolation to be used while rotating the template.
747
- **kwargs :
748
- Additional arguments to be passed to the function.
749
-
750
- Returns
751
- -------
752
- CallbackClass
753
- If callback_class was provided an instance of callback_class otherwise None.
754
-
755
- See Also
756
- --------
757
- :py:meth:`cc_setup`
758
- :py:meth:`corr_setup`
759
- :py:meth:`cam_setup`
760
- :py:meth:`flcSphericalMask_setup`
761
- """
762
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
763
-
764
- callback = callback_class
765
- if callback_class is not None and isinstance(callback_class, type):
766
- callback = callback_class(**callback_class_args)
767
-
768
- template_buffer, template_shape, template_dtype = template
769
- template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
770
- ft_target = backend.sharedarr_to_arr(*ft_target)
771
- inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
772
- numerator = backend.sharedarr_to_arr(*numerator)
773
- template_filter = backend.sharedarr_to_arr(*template_filter)
774
-
775
- norm_func = normalize_under_mask
776
- norm_template, template_mask, mask_sum = False, 1, 1
777
- if "template_mask" in kwargs:
778
- norm_template = True
779
- template_mask = backend.sharedarr_to_arr(*kwargs["template_mask"])
780
- mask_sum = backend.sum(template_mask)
781
- if backend.datatype_bytes(template_mask.dtype) == 2:
782
- norm_func = _normalize_under_mask_overflow_safe
783
- mask_sum = backend.sum(
784
- backend.astype(template_mask, backend._overflow_safe_dtype)
785
- )
786
-
787
- norm_template = conditional_execute(norm_func, norm_template)
788
- norm_numerator = (backend.sum(numerator) != 0) & (backend.size(numerator) != 1)
789
- norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
790
-
791
- norm_denominator = (backend.sum(inv_denominator) != 1) & (
792
- backend.size(inv_denominator) != 1
793
- )
794
- norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
795
- callback_func = conditional_execute(callback, callback_class is not None)
796
- template_filter_func = conditional_execute(
797
- apply_filter, backend.size(template_filter) != 1
798
- )
799
-
800
- arr = backend.preallocate_array(fast_shape, real_dtype)
801
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
802
-
803
- rfftn, irfftn = backend.build_fft(
804
- fast_shape=fast_shape,
805
- fast_ft_shape=fast_ft_shape,
806
- real_dtype=real_dtype,
807
- complex_dtype=complex_dtype,
808
- fftargs=kwargs.get("fftargs", {}),
809
- temp_real=arr,
810
- temp_fft=ft_temp,
811
- )
812
-
813
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
814
- for index in range(rotations.shape[0]):
815
- rotation = rotations[index]
816
- backend.fill(arr, 0)
817
- backend.rotate_array(
818
- arr=template,
819
- rotation_matrix=rotation,
820
- out=arr,
821
- use_geometric_center=True,
822
- order=interpolation_order,
823
- )
824
- norm_template(arr[unpadded_slice], template_mask, mask_sum)
825
-
826
- rfftn(arr, ft_temp)
827
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
828
- backend.multiply(ft_target, ft_temp, out=ft_temp)
829
- irfftn(ft_temp, arr)
830
-
831
- norm_func_numerator(arr, numerator, out=arr)
832
- norm_func_denominator(arr, inv_denominator, out=arr)
833
-
834
- callback_func(
835
- arr,
836
- rotation_matrix=rotation,
837
- rotation_index=index,
838
- **callback_class_args,
839
- )
840
-
841
- return callback
842
-
843
-
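The removed corr_scoring loop rotates the template, optionally applies a Fourier filter, correlates it against the precomputed target transform, and forwards each score map to the callback. Below is a compact 2-D sketch of that per-rotation flow; it substitutes scipy.ndimage.rotate for the backend's rotate_array and omits the numerator/denominator normalization, so it is an illustration rather than the package implementation.

import numpy as np
from scipy.ndimage import rotate

rng = np.random.default_rng(0)
target = rng.normal(size=(64, 64))
template = np.zeros_like(target)
template[:16, :16] = rng.normal(size=(16, 16))

ft_target = np.fft.rfftn(target)
best_score, best_angle = -np.inf, None
for angle in range(0, 360, 30):
    rotated = rotate(template, angle, reshape=False, order=3)
    ft_temp = np.fft.rfftn(rotated)
    scores = np.fft.irfftn(ft_target * np.conj(ft_temp), s=target.shape)
    # A callback such as MaxScoreOverRotations would keep the running maximum
    if scores.max() > best_score:
        best_score, best_angle = scores.max(), angle
print(best_score, best_angle)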
844
- def flc_scoring(
845
- template: Tuple[type, Tuple[int], type],
846
- template_mask: Tuple[type, Tuple[int], type],
847
- ft_target: Tuple[type, Tuple[int], type],
848
- ft_target2: Tuple[type, Tuple[int], type],
849
- template_filter: Tuple[type, Tuple[int], type],
850
- targetshape: Tuple[int],
851
- templateshape: Tuple[int],
852
- fast_shape: Tuple[int],
853
- fast_ft_shape: Tuple[int],
854
- rotations: NDArray,
855
- callback_class: CallbackClass,
856
- callback_class_args: Dict,
857
- interpolation_order: int,
858
- **kwargs,
859
- ) -> CallbackClass:
860
- """
861
 - Computes a normalized cross-correlation score of a target f, a template g,
862
- and a mask m:
863
-
864
- .. math::
865
-
866
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
867
- {N_m * \\sqrt{
868
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
869
- }
870
-
871
- Where:
872
-
873
- .. math::
874
-
875
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
876
-
877
- and Nm is the number of voxels within the template mask m.
878
-
879
- References
880
- ----------
881
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
882
- Microsc. Microanal. 26, 2516 (2020)
883
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
884
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
885
- """
886
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
887
-
888
- callback = callback_class
889
- if callback_class is not None and isinstance(callback_class, type):
890
- callback = callback_class(**callback_class_args)
891
-
892
- template = backend.sharedarr_to_arr(*template)
893
- template_mask = backend.sharedarr_to_arr(*template_mask)
894
- ft_target = backend.sharedarr_to_arr(*ft_target)
895
- ft_target2 = backend.sharedarr_to_arr(*ft_target2)
896
- template_filter = backend.sharedarr_to_arr(*template_filter)
897
-
898
- arr = backend.preallocate_array(fast_shape, real_dtype)
899
- temp = backend.preallocate_array(fast_shape, real_dtype)
900
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
901
-
902
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
903
- ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
904
-
905
- rfftn, irfftn = backend.build_fft(
906
- fast_shape=fast_shape,
907
- fast_ft_shape=fast_ft_shape,
908
- real_dtype=real_dtype,
909
- complex_dtype=complex_dtype,
910
- fftargs=kwargs.get("fftargs", {}),
911
- temp_real=arr,
912
- temp_fft=ft_temp,
913
- )
914
- eps = backend.eps(real_dtype)
915
- template_filter_func = conditional_execute(
916
- apply_filter, backend.size(template_filter) != 1
917
- )
918
- callback_func = conditional_execute(callback, callback_class is not None)
919
-
920
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
921
- for index in range(rotations.shape[0]):
922
- rotation = rotations[index]
923
- backend.fill(arr, 0)
924
- backend.fill(temp, 0)
925
- backend.rotate_array(
926
- arr=template,
927
- arr_mask=template_mask,
928
- rotation_matrix=rotation,
929
- out=arr,
930
- out_mask=temp,
931
- use_geometric_center=True,
932
- order=interpolation_order,
933
- )
934
 - # Given the amount of FFTs, might as well normalize properly
935
- n_observations = backend.sum(temp)
936
-
937
- normalize_under_mask(
938
- template=arr[unpadded_slice],
939
- mask=temp[unpadded_slice],
940
- mask_intensity=n_observations,
941
- )
942
-
943
- rfftn(temp, ft_temp)
944
-
945
- backend.multiply(ft_target, ft_temp, out=ft_denom)
946
- irfftn(ft_denom, temp)
947
- backend.divide(temp, n_observations, out=temp)
948
- backend.square(temp, out=temp)
949
-
950
- backend.multiply(ft_target2, ft_temp, out=ft_denom)
951
- irfftn(ft_denom, temp2)
952
- backend.divide(temp2, n_observations, out=temp2)
953
-
954
- backend.subtract(temp2, temp, out=temp)
955
- backend.maximum(temp, 0.0, out=temp)
956
- backend.sqrt(temp, out=temp)
957
- backend.multiply(temp, n_observations, out=temp)
958
-
959
- rfftn(arr, ft_temp)
960
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
961
- backend.multiply(ft_target, ft_temp, out=ft_temp)
962
- irfftn(ft_temp, arr)
963
-
964
- tol = eps
965
- # tol = 1e3 * eps * backend.max(backend.abs(temp))
966
- nonzero_indices = temp > tol
967
- backend.fill(temp2, 0)
968
- temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
969
-
970
- callback_func(
971
- temp2,
972
- rotation_matrix=rotation,
973
- rotation_index=index,
974
- **callback_class_args,
975
- )
976
-
977
- return callback
978
-
979
-
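Both flc_scoring and flcSphericalMask_setup build their denominator from two FFT correlations with the mask, i.e. the masked local E[X^2] - E[X]^2 of the target. A hedged NumPy sketch of that trick (explicit conjugation is used for correlation; names are illustrative):

import numpy as np

rng = np.random.default_rng(2)
target = rng.normal(size=(32, 32))
mask = np.zeros_like(target)
mask[:8, :8] = 1.0                      # template mask, zero-padded to the target shape
n_obs = mask.sum()

ft_mask = np.conj(np.fft.rfftn(mask))
local_mean = np.fft.irfftn(np.fft.rfftn(target) * ft_mask, s=target.shape) / n_obs
local_sqmean = np.fft.irfftn(np.fft.rfftn(target**2) * ft_mask, s=target.shape) / n_obs

local_var = np.maximum(local_sqmean - local_mean**2, 0)
denominator = n_obs * np.sqrt(local_var)   # N_m * sigma under the mask, per the docstring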
980
- def flc_scoring2(
981
- template: Tuple[type, Tuple[int], type],
982
- template_mask: Tuple[type, Tuple[int], type],
983
- ft_target: Tuple[type, Tuple[int], type],
984
- ft_target2: Tuple[type, Tuple[int], type],
985
- template_filter: Tuple[type, Tuple[int], type],
986
- targetshape: Tuple[int],
987
- templateshape: Tuple[int],
988
- fast_shape: Tuple[int],
989
- fast_ft_shape: Tuple[int],
990
- rotations: NDArray,
991
- callback_class: CallbackClass,
992
- callback_class_args: Dict,
993
- interpolation_order: int,
994
- **kwargs,
995
- ) -> CallbackClass:
996
- """
997
 - Computes a normalized cross-correlation score of a target f, a template g,
998
- and a mask m:
999
-
1000
- .. math::
1001
-
1002
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
1003
- {N_m * \\sqrt{
1004
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
1005
- }
1006
-
1007
- Where:
1008
-
1009
- .. math::
1010
-
1011
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1012
-
1013
- and Nm is the number of voxels within the template mask m.
1014
-
1015
- References
1016
- ----------
1017
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
1018
- Microsc. Microanal. 26, 2516 (2020)
1019
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1020
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
1021
- """
1022
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1023
- callback = callback_class
1024
- if callback_class is not None and isinstance(callback_class, type):
1025
- callback = callback_class(**callback_class_args)
1026
-
1027
- # Retrieve objects from shared memory
1028
- template = backend.sharedarr_to_arr(*template)
1029
- template_mask = backend.sharedarr_to_arr(*template_mask)
1030
- ft_target = backend.sharedarr_to_arr(*ft_target)
1031
- ft_target2 = backend.sharedarr_to_arr(*ft_target2)
1032
- template_filter = backend.sharedarr_to_arr(*template_filter)
1033
-
1034
- arr = backend.preallocate_array(fast_shape, real_dtype)
1035
- temp = backend.preallocate_array(fast_shape, real_dtype)
1036
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
1037
-
1038
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
1039
- ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
1040
-
1041
- eps = backend.eps(real_dtype)
1042
- template_filter_func = conditional_execute(
1043
- apply_filter, backend.size(template_filter) != 1
1044
- )
1045
- callback_func = conditional_execute(callback, callback_class is not None)
1046
-
1047
- squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
1048
- squeeze = tuple(
1049
- slice(0, stop) if i not in squeeze_axis else 0
1050
- for i, stop in enumerate(template.shape)
1051
- )
1052
- squeeze_fast = tuple(
1053
- slice(0, stop) if i not in squeeze_axis else 0
1054
- for i, stop in enumerate(fast_shape)
1055
- )
1056
- squeeze_fast_ft = tuple(
1057
- slice(0, stop) if i not in squeeze_axis else 0
1058
- for i, stop in enumerate(fast_ft_shape)
1059
- )
1060
-
1061
- rfftn, irfftn = backend.build_fft(
1062
- fast_shape=temp[squeeze_fast].shape,
1063
- fast_ft_shape=fast_ft_shape,
1064
- real_dtype=real_dtype,
1065
- complex_dtype=complex_dtype,
1066
- fftargs=kwargs.get("fftargs", {}),
1067
- inverse_fast_shape=fast_shape,
1068
- temp_real=arr[squeeze_fast],
1069
- temp_fft=ft_temp,
1070
- )
1071
- for index in range(rotations.shape[0]):
1072
- rotation = rotations[index]
1073
- backend.fill(arr, 0)
1074
- backend.fill(temp, 0)
1075
- backend.rotate_array(
1076
- arr=template[squeeze],
1077
- arr_mask=template_mask[squeeze],
1078
- rotation_matrix=rotation,
1079
- out=arr[squeeze],
1080
- out_mask=temp[squeeze],
1081
- use_geometric_center=True,
1082
- order=interpolation_order,
1083
- )
1084
 - # Given the amount of FFTs, might as well normalize properly
1085
- n_observations = backend.sum(temp)
1086
-
1087
- normalize_under_mask(
1088
- template=arr[squeeze],
1089
- mask=temp[squeeze],
1090
- mask_intensity=n_observations,
1091
- )
1092
- rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
1093
-
1094
- backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1095
- irfftn(ft_denom, temp)
1096
- backend.divide(temp, n_observations, out=temp)
1097
- backend.square(temp, out=temp)
1098
-
1099
- backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
1100
- irfftn(ft_denom, temp2)
1101
- backend.divide(temp2, n_observations, out=temp2)
1102
-
1103
- backend.subtract(temp2, temp, out=temp)
1104
- backend.maximum(temp, 0.0, out=temp)
1105
- backend.sqrt(temp, out=temp)
1106
- backend.multiply(temp, n_observations, out=temp)
1107
-
1108
- rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
1109
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1110
- backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1111
- irfftn(ft_denom, arr)
1112
-
1113
- nonzero_indices = temp > eps
1114
- backend.fill(temp2, 0)
1115
- temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1116
-
1117
- callback_func(
1118
- temp2,
1119
- rotation_matrix=rotation,
1120
- rotation_index=index,
1121
- **callback_class_args,
1122
- )
1123
- return callback
1124
-
1125
-
1126
- def mcc_scoring(
1127
- template: Tuple[type, Tuple[int], type],
1128
- template_mask: Tuple[type, Tuple[int], type],
1129
- ft_target: Tuple[type, Tuple[int], type],
1130
- ft_target2: Tuple[type, Tuple[int], type],
1131
- ft_target_mask: Tuple[type, Tuple[int], type],
1132
- template_filter: Tuple[type, Tuple[int], type],
1133
- targetshape: Tuple[int],
1134
- templateshape: Tuple[int],
1135
- fast_shape: Tuple[int],
1136
- fast_ft_shape: Tuple[int],
1137
- rotations: NDArray,
1138
- callback_class: CallbackClass,
1139
- callback_class_args: type,
1140
- interpolation_order: int,
1141
- overlap_ratio: float = 0.3,
1142
- **kwargs,
1143
- ) -> CallbackClass:
1144
- """
1145
- Computes a cross-correlation score with masks for both template and target.
1146
-
1147
- .. math::
1148
-
1149
- \\frac{
1150
- CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
1151
- }{
1152
- \\sqrt{
1153
- (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
1154
- (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
1155
- }
1156
- }
1157
-
1158
- Where:
1159
-
1160
- .. math::
1161
-
1162
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1163
-
1164
-
1165
- References
1166
- ----------
1167
- .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
1168
- .. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
31
+ last_type : type
32
+ The type of the last exception.
33
+ last_value :
34
+ The value of the last exception.
35
+ last_traceback : traceback
36
 + Traceback call stack at the point where the Exception occurred.
1169
37
 
1170
- See Also
1171
- --------
1172
- :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
38
+ Raises
39
+ ------
40
+ Exception
41
+ Re-raises the last exception.
1173
42
  """
1174
- real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1175
- callback = callback_class
1176
- if callback_class is not None and isinstance(callback_class, type):
1177
- callback = callback_class(**callback_class_args)
1178
-
1179
- # Retrieve objects from shared memory
1180
- template_buffer, template_shape, template_dtype = template
1181
- template = backend.sharedarr_to_arr(*template)
1182
- target_ft = backend.sharedarr_to_arr(*ft_target)
1183
- target_ft2 = backend.sharedarr_to_arr(*ft_target2)
1184
- template_mask = backend.sharedarr_to_arr(*template_mask)
1185
- target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
1186
- template_filter = backend.sharedarr_to_arr(*template_filter)
1187
-
1188
- axes = tuple(range(template.ndim))
1189
- eps = backend.eps(real_dtype)
1190
-
1191
- # Allocate score and process specific arrays
1192
- template_rot = backend.preallocate_array(fast_shape, real_dtype)
1193
- mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
1194
- numerator = backend.preallocate_array(fast_shape, real_dtype)
1195
-
1196
- temp = backend.preallocate_array(fast_shape, real_dtype)
1197
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
1198
- temp3 = backend.preallocate_array(fast_shape, real_dtype)
1199
- temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
1200
-
1201
- rfftn, irfftn = backend.build_fft(
1202
- fast_shape=fast_shape,
1203
- fast_ft_shape=fast_ft_shape,
1204
- real_dtype=real_dtype,
1205
- complex_dtype=complex_dtype,
1206
- fftargs=kwargs.get("fftargs", {}),
1207
- temp_real=numerator,
1208
- temp_fft=temp_ft,
1209
- )
1210
-
1211
- template_filter_func = conditional_execute(
1212
- apply_filter, backend.size(template_filter) != 1
1213
- )
1214
- callback_func = conditional_execute(callback, callback_class is not None)
1215
-
1216
- # Calculate scores across all rotations
1217
- for index in range(rotations.shape[0]):
1218
- rotation = rotations[index]
1219
- backend.fill(template_rot, 0)
1220
- backend.fill(temp, 0)
1221
-
1222
- backend.rotate_array(
1223
- arr=template,
1224
- arr_mask=template_mask,
1225
- rotation_matrix=rotation,
1226
- out=template_rot,
1227
- out_mask=temp,
1228
- use_geometric_center=True,
1229
- order=interpolation_order,
1230
- )
1231
-
1232
- backend.multiply(template_rot, temp > 0, out=template_rot)
1233
-
1234
- # template_rot_ft
1235
- rfftn(template_rot, temp_ft)
1236
- template_filter_func(ft_template=temp_ft, template_filter=template_filter)
1237
- irfftn(target_mask_ft * temp_ft, temp2)
1238
- irfftn(target_ft * temp_ft, numerator)
1239
-
1240
- # temp template_mask_rot | temp_ft template_mask_rot_ft
1241
- # Calculate overlap of masks at every point in the convolution.
1242
- # Locations with high overlap should not be taken into account.
1243
- rfftn(temp, temp_ft)
1244
- irfftn(temp_ft * target_mask_ft, mask_overlap)
1245
- mask_overlap[:] = np.round(mask_overlap)
1246
- mask_overlap[:] = np.maximum(mask_overlap, eps)
1247
- irfftn(temp_ft * target_ft, temp)
1248
-
1249
- backend.subtract(
1250
- numerator,
1251
- backend.divide(backend.multiply(temp, temp2), mask_overlap),
1252
- out=numerator,
1253
- )
43
+ if last_type is None:
44
+ return None
45
+ traceback.print_tb(last_traceback)
46
+ raise Exception(last_value)
1254
47
 
1255
- # temp_3 = fixed_denom
1256
- backend.multiply(temp_ft, target_ft2, out=temp_ft)
1257
- irfftn(temp_ft, temp3)
1258
- backend.subtract(
1259
- temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
1260
- )
1261
- backend.maximum(temp3, 0.0, out=temp3)
1262
-
1263
- # temp = moving_denom
1264
- rfftn(backend.square(template_rot), temp_ft)
1265
- backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
1266
- irfftn(temp_ft, temp)
1267
-
1268
- backend.subtract(
1269
- temp, backend.divide(backend.square(temp2), mask_overlap), out=temp
1270
- )
1271
- backend.maximum(temp, 0.0, out=temp)
1272
48
 
1273
- # temp_2 = denom
1274
- backend.multiply(temp3, temp, out=temp)
1275
- backend.sqrt(temp, temp2)
1276
-
1277
- # Pixels where `denom` is very small will introduce large
1278
- # numbers after division. To get around this problem,
1279
- # we zero-out problematic pixels.
1280
- tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
1281
- nonzero_indices = temp2 > tol
1282
-
1283
- backend.fill(temp, 0)
1284
- temp[nonzero_indices] = numerator[nonzero_indices] / temp2[nonzero_indices]
1285
- backend.clip(temp, a_min=-1, a_max=1, out=temp)
1286
-
1287
- # Apply overlap ratio threshold
1288
- number_px_threshold = overlap_ratio * backend.max(
1289
- mask_overlap, axis=axes, keepdims=True
1290
- )
1291
- temp[mask_overlap < number_px_threshold] = 0.0
49
+ def _wrap_backend(func):
50
+ @wraps(func)
51
+ def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
52
+ from tme.backends import backend as be
1292
53
 
1293
- callback_func(
1294
- temp,
1295
- rotation_matrix=rotation,
1296
- rotation_index=index,
1297
- **callback_class_args,
1298
- )
54
+ be.change_backend(backend_name, **backend_args)
55
+ return func(*args, **kwargs)
1299
56
 
1300
- return callback
57
+ return wrapper
1301
58
 
1302
59
 
1303
60
  def _setup_template_filter_apply_target_filter(
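For orientation, a sketch of how the _wrap_backend decorator added above is meant to be used: each worker call carries the backend name and arguments as keyword-only parameters and switches the compute backend inside its own process before running. The worker function and the backend name in the trailing comment are hypothetical.

from functools import wraps

def _wrap_backend(func):                        # mirrors the decorator added in this diff
    @wraps(func)
    def wrapper(*args, backend_name: str, backend_args: dict, **kwargs):
        from tme.backends import backend as be
        be.change_backend(backend_name, **backend_args)
        return func(*args, **kwargs)
    return wrapper

@_wrap_backend
def _score_split(split_index):                  # hypothetical worker
    return split_index

# Issued once per job, e.g. through joblib.Parallel (backend name is a placeholder):
# _score_split(0, backend_name="numpyfftw", backend_args={})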
@@ -1306,42 +63,59 @@ def _setup_template_filter_apply_target_filter(
1306
63
  irfftn: Callable,
1307
64
  fast_shape: Tuple[int],
1308
65
  fast_ft_shape: Tuple[int],
66
+ pad_template_filter: bool = True,
1309
67
  ):
1310
68
  filter_template = isinstance(matching_data.template_filter, Compose)
1311
69
  filter_target = isinstance(matching_data.target_filter, Compose)
1312
70
 
1313
- template_filter = backend.full(shape=(1,), fill_value=1, dtype=backend._float_dtype)
71
+ template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
1314
72
 
1315
73
  if not filter_template and not filter_target:
1316
74
  return template_filter
1317
75
 
1318
- target_temp = backend.astype(
1319
- backend.topleft_pad(matching_data.target, fast_shape), backend._float_dtype
1320
- )
1321
- target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1322
-
1323
- filter_shape = backend.multiply(fast_ft_shape, 1 - matching_data._batch_mask)
1324
- filter_shape[filter_shape == 0] = 1
1325
- fast_shape = backend.multiply(fast_shape, 1 - matching_data._batch_mask)
1326
- fast_shape = fast_shape[fast_shape != 0]
1327
-
1328
- fast_shape = tuple(int(x) for x in fast_shape)
1329
- filter_shape = tuple(int(x) for x in filter_shape)
76
+ inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
77
+ filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
78
+ filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
79
+ fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
80
+ fast_shape = tuple(int(x) for x in fast_shape if x != 0)
81
+
82
+ fastt_shape, fastt_ft_shape = fast_shape, filter_shape
83
+ if filter_template and not pad_template_filter:
84
+ # FFT shape acrobatics for faster filter application
85
+ _, fastt_shape, _, _ = matching_data._fourier_padding(
86
+ target_shape=be.to_numpy_array(matching_data._template.shape),
87
+ template_shape=be.to_numpy_array(
88
+ [1 for _ in matching_data._template.shape]
89
+ ),
90
+ pad_fourier=False,
91
+ )
92
+ matching_data.template = be.reverse(
93
+ be.topleft_pad(matching_data.template, fastt_shape)
94
+ )
95
+ matching_data.template_mask = be.reverse(
96
+ be.topleft_pad(matching_data.template_mask, fastt_shape)
97
+ )
98
+ matching_data._set_matching_dimension(
99
+ target_dims=matching_data._target_dims,
100
+ template_dims=matching_data._template_dims,
101
+ )
102
+ fastt_ft_shape = [int(x) for x in matching_data._output_template_shape]
103
+ fastt_ft_shape[-1] = fastt_ft_shape[-1] // 2 + 1
1330
104
 
1331
- rfftn(target_temp, target_temp_ft)
1332
- if isinstance(matching_data.template_filter, Compose):
105
+ target_temp = be.topleft_pad(matching_data.target, fast_shape)
106
+ target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
107
+ target_temp_ft = rfftn(target_temp, target_temp_ft)
108
+ if filter_template:
1333
109
  template_filter = matching_data.template_filter(
1334
- shape=fast_shape,
110
+ shape=fastt_shape,
1335
111
  return_real_fourier=True,
1336
112
  shape_is_real_fourier=False,
1337
113
  data_rfft=target_temp_ft,
1338
114
  batch_dimension=matching_data._target_dims,
1339
- )
1340
- template_filter = template_filter["data"]
1341
- template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1342
- template_filter = backend.reshape(template_filter, filter_shape)
115
+ )["data"]
116
+ template_filter = be.reshape(template_filter, fastt_ft_shape)
1343
117
 
1344
- if isinstance(matching_data.target_filter, Compose):
118
+ if filter_target:
1345
119
  target_filter = matching_data.target_filter(
1346
120
  shape=fast_shape,
1347
121
  return_real_fourier=True,
@@ -1349,15 +123,12 @@ def _setup_template_filter_apply_target_filter(
1349
123
  data_rfft=target_temp_ft,
1350
124
  weight_type=None,
1351
125
  batch_dimension=matching_data._target_dims,
1352
- )
1353
- target_filter = target_filter["data"]
1354
- target_filter = backend.reshape(target_filter, filter_shape)
1355
- backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
126
+ )["data"]
127
+ target_filter = be.reshape(target_filter, filter_shape)
128
+ target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1356
129
 
1357
- irfftn(target_temp_ft, target_temp)
1358
- matching_data._target = backend.topleft_pad(
1359
- target_temp, matching_data.target.shape
1360
- )
130
+ target_temp = irfftn(target_temp_ft, target_temp)
131
+ matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)
1361
132
 
1362
133
  return template_filter
1363
134
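The reworked _setup_template_filter_apply_target_filter applies the target filter once in Fourier space and pads/crops around the FFT-friendly shape. A minimal NumPy sketch of that pattern, with a toy low-pass filter standing in for a Compose filter stack:

import numpy as np

rng = np.random.default_rng(3)
target = rng.normal(size=(30, 30))
fast_shape = (32, 32)                          # padded, FFT-friendly shape

target_pad = np.zeros(fast_shape)
target_pad[:30, :30] = target                  # top-left padding, like topleft_pad

target_ft = np.fft.rfftn(target_pad)
fy = np.fft.fftfreq(fast_shape[0])[:, None]
fx = np.fft.rfftfreq(fast_shape[1])[None, :]
target_filter = (np.sqrt(fy**2 + fx**2) < 0.25).astype(float)   # toy low-pass stand-in

target_ft *= target_filter
target_pad = np.fft.irfftn(target_ft, s=fast_shape)
target = target_pad[:30, :30]                  # crop back to the original target shape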
 
@@ -1371,13 +142,14 @@ def device_memory_handler(func: Callable):
1371
142
  last_type, last_value, last_traceback = sys.exc_info()
1372
143
  try:
1373
144
  with SharedMemoryManager() as smh:
1374
- with backend.set_device(kwargs.get("gpu_index", 0)):
145
+ gpu_index = kwargs.pop("gpu_index") if "gpu_index" in kwargs else 0
146
+ with be.set_device(gpu_index):
1375
147
  return_value = func(shared_memory_handler=smh, *args, **kwargs)
1376
148
  except Exception as e:
1377
149
  print(e)
1378
150
  last_type, last_value, last_traceback = sys.exc_info()
1379
151
  finally:
1380
- handle_traceback(last_type, last_value, last_traceback)
152
+ _handle_traceback(last_type, last_value, last_traceback)
1381
153
  return return_value
1382
154
 
1383
155
  return inner_function
@@ -1391,18 +163,20 @@ def scan(
1391
163
  n_jobs: int = 4,
1392
164
  callback_class: CallbackClass = None,
1393
165
  callback_class_args: Dict = {},
1394
- fftargs: Dict = {},
1395
166
  pad_fourier: bool = True,
167
+ pad_template_filter: bool = True,
1396
168
  interpolation_order: int = 3,
1397
169
  jobs_per_callback_class: int = 8,
1398
- **kwargs,
1399
- ) -> Tuple:
170
+ shared_memory_handler=None,
171
+ ) -> Optional[Tuple]:
1400
172
  """
1401
- Perform template matching.
173
+ Run template matching.
174
+
175
+ .. warning:: ``matching_data`` might be altered or destroyed during computation.
1402
176
 
1403
177
  Parameters
1404
178
  ----------
1405
- matching_data : MatchingData
179
+ matching_data : :py:class:`tme.matching_data.MatchingData`
1406
180
  Template matching data.
1407
181
  matching_setup : Callable
1408
182
  Function pointer to setup function.
@@ -1414,21 +188,21 @@ def scan(
  Analyzer class pointer to operate on computed scores.
  callback_class_args : dict, optional
  Arguments passed to the callback_class. Default is an empty dictionary.
- fftargs : dict, optional
- Arguments for the FFT operations. Default is an empty dictionary.
  pad_fourier: bool, optional
  Whether to pad target and template to the full convolution shape.
+ pad_template_filter: bool, optional
+ Whether to pad potential template filters to the full convolution shape.
  interpolation_order : int, optional
  Order of spline interpolation for rotations.
  jobs_per_callback_class : int, optional
  How many jobs should be processed by a single callback_class instance,
  if one is provided.
- **kwargs : various
- Additional keyword arguments.
+ shared_memory_handler : type, optional
+ Manager for shared memory objects, None by default.

  Returns
  -------
- Tuple
+ Optional[Tuple]
  The merged results from callback_class if provided otherwise None.

  Examples
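Given the signature changes above, a call to scan would look roughly as follows. This is an unverified usage sketch: the MatchingData constructor, get_rotation_matrices, and the "FLCSphericalMask" registry key follow the package documentation but are assumptions, not something this diff guarantees:

import numpy as np
from tme.analyzer import MaxScoreOverRotations
from tme.matching_data import MatchingData
from tme.matching_utils import get_rotation_matrices
from tme.matching_exhaustive import scan, MATCHING_EXHAUSTIVE_REGISTER

target = np.random.rand(50, 50, 50).astype(np.float32)
template = np.random.rand(12, 12, 12).astype(np.float32)

matching_data = MatchingData(target=target, template=template)
matching_data.rotations = get_rotation_matrices(angular_sampling=60, dim=3)

matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER["FLCSphericalMask"]
result = scan(
    matching_data=matching_data,
    matching_setup=matching_setup,
    matching_score=matching_score,
    n_jobs=2,
    callback_class=MaxScoreOverRotations,
    pad_fourier=True,
    pad_template_filter=True,  # new keyword in this version
)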
@@ -1450,159 +224,100 @@ def scan(

  """
  matching_data.to_backend()
- shape_diff = backend.subtract(
- matching_data._output_target_shape, matching_data._output_template_shape
- )
- shape_diff = backend.multiply(shape_diff, 1 - matching_data._batch_mask)
-
- if backend.sum(shape_diff < 0) and not pad_fourier:
- warnings.warn(
- "Target is larger than template and Fourier padding is turned off."
- " This can lead to shifted results. You can swap template and target,"
- " zero-pad the target or turn off template centering."
- )
- fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
- pad_fourier=pad_fourier
- )
-
- callback_class_args["fourier_shift"] = fourier_shift
- rfftn, irfftn = backend.build_fft(
+ (
+ conv_shape,
+ fast_shape,
+ fast_ft_shape,
+ fourier_shift,
+ ) = matching_data.fourier_padding(pad_fourier=pad_fourier)
+ template_shape = matching_data.template.shape
+
+ rfftn, irfftn = be.build_fft(
  fast_shape=fast_shape,
  fast_ft_shape=fast_ft_shape,
- real_dtype=backend._float_dtype,
- complex_dtype=backend._complex_dtype,
- fftargs=fftargs,
+ real_dtype=be._float_dtype,
+ complex_dtype=be._complex_dtype,
  )
-
  template_filter = _setup_template_filter_apply_target_filter(
  matching_data=matching_data,
  rfftn=rfftn,
  irfftn=irfftn,
  fast_shape=fast_shape,
  fast_ft_shape=fast_ft_shape,
+ pad_template_filter=pad_template_filter,
  )
+ template_filter = be.astype(be.to_backend_array(template_filter), be._float_dtype)

  setup = matching_setup(
  rfftn=rfftn,
  irfftn=irfftn,
  template=matching_data.template,
+ template_filter=template_filter,
  template_mask=matching_data.template_mask,
  target=matching_data.target,
  target_mask=matching_data.target_mask,
  fast_shape=fast_shape,
  fast_ft_shape=fast_ft_shape,
- callback_class=callback_class,
- callback_class_args=callback_class_args,
- **kwargs,
+ shared_memory_handler=shared_memory_handler,
  )
  rfftn, irfftn = None, None
-
- template_filter = backend.to_backend_array(template_filter)
- template_filter = backend.astype(template_filter, backend._float_dtype)
- template_filter_buffer = backend.arr_to_sharedarr(
- arr=template_filter,
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
- )
- setup["template_filter"] = (
- template_filter_buffer,
- template_filter.shape,
- template_filter.dtype,
- )
-
- callback_class_args["translation_offset"] = backend.astype(
- matching_data._translation_offset, int
- )
- callback_class_args["thread_safe"] = n_jobs > 1
- callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
-
- n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
- callback_class = setup.pop("callback_class", callback_class)
- callback_class_args = setup.pop("callback_class_args", callback_class_args)
- callback_classes = [callback_class for _ in range(n_callback_classes)]
-
- convolution_mode = "same"
- if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
- convolution_mode = "valid"
-
- callback_class_args["fourier_shift"] = fourier_shift
- callback_class_args["convolution_mode"] = convolution_mode
- callback_class_args["targetshape"] = setup["targetshape"]
- callback_class_args["templateshape"] = setup["templateshape"]
-
- if callback_class == MaxScoreOverRotations:
- callback_classes = [
- class_name(
- score_space_shape=fast_shape,
- score_space_dtype=backend._float_dtype,
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
- rotation_space_dtype=backend._int_dtype,
- **callback_class_args,
- )
- for class_name in callback_classes
- ]
-
- matching_data._target, matching_data._template = None, None
- matching_data._target_mask, matching_data._template_mask = None, None
-
- setup["fftargs"] = fftargs.copy()
- convolution_mode = "same"
- if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
- convolution_mode = "valid"
- setup["convolution_mode"] = convolution_mode
  setup["interpolation_order"] = interpolation_order
- rotation_list = matching_data._split_rotations_on_jobs(n_jobs)
-
- backend.free_cache()
-
- def _run_scoring(backend_name, backend_args, rotations, **kwargs):
- from tme.backends import backend
+ setup["template_filter"] = be.to_sharedarr(template_filter, shared_memory_handler)
+
+ offset = be.to_backend_array(matching_data._translation_offset)
+ convmode = "valid" if getattr(matching_data, "_is_padded", False) else "same"
+ default_callback_args = {
+ "offset": be.astype(offset, be._int_dtype),
+ "thread_safe": n_jobs > 1,
+ "fourier_shift": fourier_shift,
+ "convolution_mode": convmode,
+ "targetshape": matching_data.target.shape,
+ "templateshape": template_shape,
+ "convolution_shape": conv_shape,
+ "fast_shape": fast_shape,
+ "indices": getattr(matching_data, "indices", None),
+ "shared_memory_handler": shared_memory_handler,
+ "only_unique_rotations": True,
+ }
+ default_callback_args.update(callback_class_args)

- backend.change_backend(backend_name, **backend_args)
- return matching_score(rotations=rotations, **kwargs)
+ matching_data._free_data()
+ be.free_cache()

+ # For highly parallel jobs, blocking in certain analyzers becomes a bottleneck
+ if getattr(callback_class, "shared", True):
+ jobs_per_callback_class = 1
+ n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
+ callback_classes = [
+ callback_class(
+ shape=fast_shape,
+ **default_callback_args,
+ )
+ if callback_class is not None
+ else None
+ for _ in range(n_callback_classes)
+ ]
  callbacks = Parallel(n_jobs=n_jobs)(
- delayed(_run_scoring)(
- backend_name=backend._backend_name,
- backend_args=backend._backend_args,
+ delayed(_wrap_backend(matching_score))(
+ backend_name=be._backend_name,
+ backend_args=be._backend_args,
  rotations=rotation,
- callback_class=callback_classes[index % n_callback_classes],
- callback_class_args=callback_class_args,
+ callback=callback_classes[index % n_callback_classes],
  **setup,
  )
- for index, rotation in enumerate(rotation_list)
+ for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
  )

- callbacks = callbacks[0:n_callback_classes]
  callbacks = [
- tuple(
- callback._postprocess(
- fourier_shift=fourier_shift,
- convolution_mode=convolution_mode,
- targetshape=setup["targetshape"],
- templateshape=setup["templateshape"],
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
- )
- )
- if hasattr(callback, "_postprocess")
- else tuple(callback)
+ tuple(callback._postprocess(**default_callback_args))
  for callback in callbacks
  if callback is not None
  ]
- backend.free_cache()
+ be.free_cache()

- merged_callback = None
  if callback_class is not None:
- score_indices = None
- if hasattr(matching_data, "indices"):
- score_indices = matching_data.indices
- merged_callback = callback_class.merge(
- callbacks,
- **callback_class_args,
- score_indices=score_indices,
- inner_merge=True,
- )
-
- return merged_callback
+ return callback_class.merge(callbacks, **default_callback_args)
+ return None
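The rewritten body above splits the rotation set across jobs, hands each chunk to a worker together with one of n_callback_classes analyzer instances, and merges the partial results at the end. A toy sketch of that fan-out/merge pattern with joblib, using plain Python stand-ins for the scoring kernel and the merge step:

from joblib import Parallel, delayed

def score_rotations(rotations):
    # stand-in for the scoring kernel: best (score, rotation) within this chunk
    return max((rot % 7, rot) for rot in rotations)

rotation_chunks = [range(0, 4), range(4, 8), range(8, 12)]  # stand-in for _split_rotations_on_jobs
partials = Parallel(n_jobs=3)(
    delayed(score_rotations)(chunk) for chunk in rotation_chunks
)
best_score, best_rotation = max(partials)  # stand-in for callback_class.merge
print(best_score, best_rotation)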


  def scan_subsets(
@@ -1616,10 +331,12 @@ def scan_subsets(
  template_splits: Dict = {},
  pad_target_edges: bool = False,
  pad_fourier: bool = True,
+ pad_template_filter: bool = True,
  interpolation_order: int = 3,
  jobs_per_callback_class: int = 8,
- **kwargs,
- ) -> Tuple:
+ backend_name: str = None,
+ backend_args: Dict = {},
+ ) -> Optional[Tuple]:
  """
  Wrapper around :py:meth:`scan` that supports matching on splits
  of ``matching_data``.
@@ -1651,21 +368,17 @@ def scan_subsets(
  along each axis.
  pad_fourier: bool, optional
  Whether to pad target and template to the full convolution shape.
+ pad_template_filter: bool, optional
+ Whether to pad potential template filters to the full convolution shape.
  interpolation_order : int, optional
  Order of spline interpolation for rotations.
  jobs_per_callback_class : int, optional
  How many jobs should be processed by a single callback_class instance,
  if ones is provided.
- **kwargs : various
- Additional arguments.
-
- Notes
- -----
- Objects in matching_data might be destroyed during computation.

  Returns
  -------
- Tuple
+ Optional[Tuple]
  The merged results from callback_class if provided otherwise None.

  Examples
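scan_subsets enumerates every combination of target and template splits and dispatches one scan per pair, spreading the work over outer_jobs workers (and, where applicable, over devices via gpu_index = index % outer_jobs). A small sketch of that scheduling idea with plain slices; the split tuples below are made up for illustration:

from itertools import product

target_splits = [
    (slice(0, 32), slice(0, 64)),
    (slice(32, 64), slice(0, 64)),
]
template_splits = [(slice(None), slice(None))]  # templates are usually not split

splits = tuple(product(target_splits, template_splits))
outer_jobs = 2
for index, (target_slice, template_slice) in enumerate(splits):
    gpu_index = index % outer_jobs
    # each pair would become matching_data.subset_by_slice(...) passed to scan()
    print(gpu_index, target_slice, template_slice)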
@@ -1720,73 +433,59 @@ def scan_subsets(
  >>> target_splits = target_splits,
  >>> )

- The retuned ``results`` tuple contains the output of the chosen analyzer.
+ The ``results`` tuple contains the output of the chosen analyzer.

  See Also
  --------
  :py:meth:`tme.matching_utils.compute_parallelization_schedule`
  """
- target_splits = split_numpy_array_slices(
- matching_data._target.shape, splits=target_splits
- )
+ template_splits = split_shape(matching_data._template.shape, splits=template_splits)
+ target_splits = split_shape(matching_data._target.shape, splits=target_splits)
  if (len(target_splits) > 1) and not pad_target_edges:
  warnings.warn(
  "Target splitting without padding target edges leads to unreliable "
  "similarity estimates around the split border."
  )
-
- template_splits = split_numpy_array_slices(
- matching_data._template.shape, splits=template_splits
- )
- target_pad = matching_data.target_padding(pad_target=pad_target_edges)
+ splits = tuple(product(target_splits, template_splits))

  outer_jobs, inner_jobs = job_schedule
- results = Parallel(n_jobs=outer_jobs)(
- delayed(_run_inner)(
- backend_name=backend._backend_name,
- backend_args=backend._backend_args,
- matching_data=matching_data.subset_by_slice(
- target_slice=target_split,
- target_pad=target_pad,
- template_slice=template_split,
- ),
- matching_score=matching_score,
- matching_setup=matching_setup,
- n_jobs=inner_jobs,
+ target_pad = matching_data.target_padding(pad_target=pad_target_edges)
+ if hasattr(be, "scan"):
+ corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
+ results = be.scan(
+ matching_data=matching_data,
+ splits=splits,
+ n_jobs=outer_jobs,
+ rotate_mask=matching_score != corr_scoring,
  callback_class=callback_class,
- callback_class_args=callback_class_args,
- interpolation_order=interpolation_order,
- pad_fourier=pad_fourier,
- gpu_index=index % outer_jobs,
- **kwargs,
  )
- for index, (target_split, template_split) in enumerate(
- product(target_splits, template_splits)
+ else:
+ results = Parallel(n_jobs=outer_jobs)(
+ delayed(_wrap_backend(scan))(
+ backend_name=be._backend_name,
+ backend_args=be._backend_args,
+ matching_data=matching_data.subset_by_slice(
+ target_slice=target_split,
+ target_pad=target_pad,
+ template_slice=template_split,
+ ),
+ matching_score=matching_score,
+ matching_setup=matching_setup,
+ n_jobs=inner_jobs,
+ callback_class=callback_class,
+ callback_class_args=callback_class_args,
+ interpolation_order=interpolation_order,
+ pad_fourier=pad_fourier,
+ gpu_index=index % outer_jobs,
+ pad_template_filter=pad_template_filter,
+ )
+ for index, (target_split, template_split) in enumerate(splits)
  )
- )
-
- matching_data._target, matching_data._template = None, None
- matching_data._target_mask, matching_data._template_mask = None, None

- candidates = None
+ matching_data._free_data()
  if callback_class is not None:
- candidates = callback_class.merge(
- results, **callback_class_args, inner_merge=False
- )
-
- return candidates
-
-
- MATCHING_EXHAUSTIVE_REGISTER = {
- "CC": (cc_setup, corr_scoring),
- "LCC": (lcc_setup, corr_scoring),
- "CORR": (corr_setup, corr_scoring),
- "CAM": (cam_setup, corr_scoring),
- "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
- "FLC": (flc_setup, flc_scoring),
- "FLC2": (flc_setup, flc_scoring2),
- "MCC": (mcc_setup, mcc_scoring),
- }
+ return callback_class.merge(results, **callback_class_args)
+ return None
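The MATCHING_EXHAUSTIVE_REGISTER literal removed above maps score names to (setup, scoring) pairs and is still consulted elsewhere in this file; register_matching_exhaustive, documented below, guards that mapping against duplicate names. A minimal stand-alone sketch of such a registry; the dictionary and functions here are illustrative only, not pytme's actual objects:

# Illustrative registry of named scoring schemes.
REGISTER = {}

def register_matching(name, setup, scoring):
    if name in REGISTER:
        raise ValueError(f"A method named '{name}' is already registered.")
    REGISTER[name] = (setup, scoring)

register_matching("CORR", setup=lambda **kw: kw, scoring=lambda **kw: 0.0)
matching_setup, matching_score = REGISTER["CORR"]  # lookup as done in scan_subsets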


  def register_matching_exhaustive(
@@ -1803,20 +502,17 @@ def register_matching_exhaustive(
  matching : str
  Name of the matching method.
  matching_setup : Callable
- The setup function associated with the name.
+ Corresponding setup function.
  matching_scoring : Callable
- The scoring function associated with the name.
+ Corresponing scoring function.
  memory_class : MatchingMemoryUsage
- The custom memory estimation class extending
- :py:class:`tme.matching_memory.MatchingMemoryUsage`.
+ Child of :py:class:`tme.memory.MatchingMemoryUsage`.

  Raises
  ------
  ValueError
- If a function with the name ``matching`` already exists in the registry.
- ValueError
- If ``memory_class`` is not a subclass of
- :py:class:`tme.matching_memory.MatchingMemoryUsage`.
+ If a function with the name ``matching`` already exists in the registry, or
+ if ``memory_class`` is no child of :py:class:`tme.memory.MatchingMemoryUsage`.
  """

  if matching in MATCHING_EXHAUSTIVE_REGISTER: