pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -4,1357 +4,119 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- import os
8
7
  import sys
9
8
  import warnings
10
- from copy import deepcopy
11
- from itertools import product
12
- from typing import Callable, Tuple, Dict
9
+ import traceback
13
10
  from functools import wraps
11
+ from itertools import product
12
+ from typing import Callable, Tuple, Dict, Optional
13
+
14
14
  from joblib import Parallel, delayed
15
15
  from multiprocessing.managers import SharedMemoryManager
16
16
 
17
- import numpy as np
18
- from scipy.ndimage import laplace
19
-
20
- from .analyzer import MaxScoreOverRotations
21
- from .matching_utils import (
22
- apply_convolution_mode,
23
- handle_traceback,
24
- split_numpy_array_slices,
25
- conditional_execute,
26
- )
27
- from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
28
- # from .preprocessing import Compose
29
- from .preprocessor import Preprocessor
30
- from .matching_data import MatchingData
31
- from .backends import backend
32
- from .types import NDArray, CallbackClass
33
-
34
- os.environ["MKL_NUM_THREADS"] = "1"
35
- os.environ["OMP_NUM_THREADS"] = "1"
36
- os.environ["PYFFTW_NUM_THREADS"] = "1"
37
- os.environ["OPENBLAS_NUM_THREADS"] = "1"
38
-
39
-
40
- def _run_inner(backend_name, backend_args, **kwargs):
41
- from tme.backends import backend
17
+ from .backends import backend as be
18
+ from .preprocessing import Compose
19
+ from .matching_utils import split_shape
20
+ from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
21
+ from .types import CallbackClass, MatchingData
22
+ from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
42
23
 
43
- backend.change_backend(backend_name, **backend_args)
44
- return scan(**kwargs)
45
24
 
46
-
47
- def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
25
+ def _handle_traceback(last_type, last_value, last_traceback):
48
26
  """
49
- Standardizes the values in in template by subtracting the mean and dividing by the
50
- standard deviation based on the elements in mask. Subsequently, the template is
51
- multiplied by the mask.
27
+ Handle sys.exc_info().
52
28
 
53
29
  Parameters
54
30
  ----------
55
- template : NDArray
56
- The data array to be normalized. This array is modified in-place.
57
- mask : NDArray
58
- A boolean array of the same shape as `template`. True values indicate the positions in `template`
59
- to consider for normalization.
60
- mask_intensity : float
61
- Mask intensity used to compute expectations.
62
-
63
- References
64
- ----------
65
- .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
66
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
67
- .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
68
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
69
- 13375 (2023)
70
-
71
- Returns
72
- -------
73
- None
74
- This function modifies `template` in-place and does not return any value.
75
- """
76
- masked_mean = backend.sum(backend.multiply(template, mask))
77
- masked_mean = backend.divide(masked_mean, mask_intensity)
78
- masked_std = backend.sum(backend.multiply(backend.square(template), mask))
79
- masked_std = backend.subtract(
80
- masked_std / mask_intensity, backend.square(masked_mean)
81
- )
82
- masked_std = backend.sqrt(backend.maximum(masked_std, 0))
83
-
84
- backend.subtract(template, masked_mean, out=template)
85
- backend.divide(template, masked_std, out=template)
86
- backend.multiply(template, mask, out=template)
87
- return None
88
-
89
-
90
- def apply_filter(ft_template, template_filter):
91
- # This is an approximation to applying the mask, irfftn, normalize, rfftn
92
- std_before = backend.std(ft_template)
93
- backend.multiply(ft_template, template_filter, out=ft_template)
94
- backend.multiply(
95
- ft_template, std_before / backend.std(ft_template), out=ft_template
96
- )
97
-
98
-
99
- def cc_setup(
100
- rfftn: Callable,
101
- irfftn: Callable,
102
- template: NDArray,
103
- target: NDArray,
104
- fast_shape: Tuple[int],
105
- fast_ft_shape: Tuple[int],
106
- real_dtype: type,
107
- complex_dtype: type,
108
- shared_memory_handler: Callable,
109
- callback_class: Callable,
110
- callback_class_args: Dict,
111
- **kwargs,
112
- ) -> Dict:
113
- """
114
- Setup to compute the cross-correlation between a target f and template g
115
- defined as:
116
-
117
- .. math::
118
-
119
- \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
120
-
121
-
122
- See Also
123
- --------
124
- :py:meth:`corr_scoring`
125
- :py:class:`tme.matching_optimization.CrossCorrelation`
126
- """
127
- target_shape = target.shape
128
- target_pad = backend.topleft_pad(target, fast_shape)
129
- target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
130
-
131
- rfftn(target_pad, target_pad_ft)
132
-
133
- target_ft_out = backend.arr_to_sharedarr(
134
- arr=target_pad_ft, shared_memory_handler=shared_memory_handler
135
- )
136
-
137
- template_out = backend.arr_to_sharedarr(
138
- arr=template, shared_memory_handler=shared_memory_handler
139
- )
140
- inv_denominator_buffer = backend.arr_to_sharedarr(
141
- arr=backend.preallocate_array(1, real_dtype) + 1,
142
- shared_memory_handler=shared_memory_handler,
143
- )
144
- numerator2_buffer = backend.arr_to_sharedarr(
145
- arr=backend.preallocate_array(1, real_dtype),
146
- shared_memory_handler=shared_memory_handler,
147
- )
148
-
149
- target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
150
- template_tuple = (template_out, template.shape, real_dtype)
151
-
152
- inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
153
- numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
154
-
155
- ret = {
156
- "template": template_tuple,
157
- "ft_target": target_ft_tuple,
158
- "inv_denominator": inv_denominator_tuple,
159
- "numerator2": numerator2_tuple,
160
- "targetshape": target_shape,
161
- "templateshape": template.shape,
162
- "fast_shape": fast_shape,
163
- "fast_ft_shape": fast_ft_shape,
164
- "real_dtype": real_dtype,
165
- "complex_dtype": complex_dtype,
166
- "callback_class": callback_class,
167
- "callback_class_args": callback_class_args,
168
- "use_memmap": kwargs.get("use_memmap", False),
169
- "temp_dir": kwargs.get("temp_dir", None),
170
- }
171
-
172
- return ret
173
-
174
-
175
- def lcc_setup(**kwargs) -> Dict:
176
- """
177
- Setup to compute the cross-correlation between a laplace transformed target f
178
- and laplace transformed template g defined as:
179
-
180
- .. math::
181
-
182
- \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)
183
-
184
-
185
- See Also
186
- --------
187
- :py:meth:`corr_scoring`
188
- :py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
189
- """
190
- kwargs["target"] = laplace(kwargs["target"], mode="wrap")
191
- kwargs["template"] = laplace(kwargs["template"], mode="wrap")
192
- return cc_setup(**kwargs)
193
-
194
-
195
- def corr_setup(
196
- rfftn: Callable,
197
- irfftn: Callable,
198
- template: NDArray,
199
- template_mask: NDArray,
200
- target: NDArray,
201
- fast_shape: Tuple[int],
202
- fast_ft_shape: Tuple[int],
203
- real_dtype: type,
204
- complex_dtype: type,
205
- shared_memory_handler: Callable,
206
- callback_class: Callable,
207
- callback_class_args: Dict,
208
- **kwargs,
209
- ) -> Dict:
210
- """
211
- Setup to compute a normalized cross-correlation between a target f and a template g.
212
-
213
- .. math::
214
-
215
- \\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
216
- {(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}}
217
-
218
- Where:
219
-
220
- .. math::
221
-
222
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
223
-
224
- and m is a mask with the same dimension as the template filled with ones.
225
-
226
- References
227
- ----------
228
- .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
229
- and Magic.
230
-
231
- See Also
232
- --------
233
- :py:meth:`corr_scoring`
234
- :py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
235
- """
236
- target_pad = backend.topleft_pad(target, fast_shape)
237
-
238
- # The exact composition of the denominator is debatable
239
- # scikit-image match_template multiplies the running sum of the target
240
- # with a scaling factor derived from the template. This is probably appropriate
241
- # in pattern matching situations where the template exists in the target
242
- window_template = backend.topleft_pad(template_mask, fast_shape)
243
- ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
244
- rfftn(window_template, ft_window_template)
245
- window_template = None
246
-
247
- # Target and squared target window sums
248
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
249
- ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
250
- denominator = backend.preallocate_array(fast_shape, real_dtype)
251
- target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
252
- rfftn(target_pad, ft_target)
253
-
254
- rfftn(backend.square(target_pad), ft_target2)
255
- backend.multiply(ft_target2, ft_window_template, out=ft_target2)
256
- irfftn(ft_target2, denominator)
257
-
258
- backend.multiply(ft_target, ft_window_template, out=ft_window_template)
259
- irfftn(ft_window_template, target_window_sum)
260
-
261
- target_pad, ft_target2, ft_window_template = None, None, None
262
-
263
- # Normalizing constants
264
- n_observations = backend.sum(template_mask)
265
- template_mean = backend.sum(backend.multiply(template, template_mask))
266
- template_mean = backend.divide(template_mean, n_observations)
267
- template_ssd = backend.sum(
268
- backend.square(
269
- backend.multiply(backend.multiply(template, template_mean), template_mask)
270
- )
271
- )
272
- template_volume = np.prod(template.shape)
273
- backend.multiply(template, template_mask, out=template)
274
-
275
- # Final numerator is score - numerator2
276
- numerator2 = backend.multiply(target_window_sum, template_mean)
277
-
278
- # Compute denominator
279
- backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
280
- backend.divide(target_window_sum, template_volume, out=target_window_sum)
281
-
282
- backend.subtract(denominator, target_window_sum, out=denominator)
283
- backend.multiply(denominator, template_ssd, out=denominator)
284
- backend.maximum(denominator, 0, out=denominator)
285
- backend.sqrt(denominator, out=denominator)
286
- target_window_sum = None
287
-
288
- # Invert denominator to compute final score as product
289
- denominator_mask = denominator > backend.eps(denominator.dtype)
290
- inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
291
- inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]
292
-
293
- # Convert arrays used in subsequent fitting to SharedMemory objects
294
- template_buffer = backend.arr_to_sharedarr(
295
- arr=template, shared_memory_handler=shared_memory_handler
296
- )
297
- target_ft_buffer = backend.arr_to_sharedarr(
298
- arr=ft_target, shared_memory_handler=shared_memory_handler
299
- )
300
- inv_denominator_buffer = backend.arr_to_sharedarr(
301
- arr=inv_denominator, shared_memory_handler=shared_memory_handler
302
- )
303
- numerator2_buffer = backend.arr_to_sharedarr(
304
- arr=numerator2, shared_memory_handler=shared_memory_handler
305
- )
306
-
307
- template_tuple = (template_buffer, deepcopy(template.shape), real_dtype)
308
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
309
-
310
- inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
311
- numerator2_tuple = (numerator2_buffer, fast_shape, real_dtype)
312
-
313
- ft_target, inv_denominator, numerator2 = None, None, None
314
-
315
- ret = {
316
- "template": template_tuple,
317
- "ft_target": target_ft_tuple,
318
- "inv_denominator": inv_denominator_tuple,
319
- "numerator2": numerator2_tuple,
320
- "targetshape": deepcopy(target.shape),
321
- "templateshape": deepcopy(template.shape),
322
- "fast_shape": fast_shape,
323
- "fast_ft_shape": fast_ft_shape,
324
- "real_dtype": real_dtype,
325
- "complex_dtype": complex_dtype,
326
- "callback_class": callback_class,
327
- "callback_class_args": callback_class_args,
328
- "template_mean": kwargs.get("template_mean", template_mean),
329
- }
330
-
331
- return ret
332
-
333
-
334
- def cam_setup(**kwargs):
335
- """
336
- Setup to compute a normalized cross-correlation between a target f and a template g
337
- over their means. In practice this can be expressed like the cross-correlation
338
- CORR defined in :py:meth:`corr_scoring`, so that:
339
-
340
- Notes
341
- -----
342
-
343
- .. math::
344
-
345
- \\text{CORR}(f-\\overline{f}, g - \\overline{g})
346
-
347
- Where
348
-
349
- .. math::
350
-
351
- \\overline{f}, \\overline{g}
352
-
353
- are the mean of f and g respectively.
354
-
355
- References
356
- ----------
357
- .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
358
- and Magic.
359
-
360
- See Also
361
- --------
362
- :py:meth:`corr_scoring`.
363
- :py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
364
- """
365
- kwargs["template"] = kwargs["template"] - kwargs["template"].mean()
366
- return corr_setup(**kwargs)
367
-
368
-
369
- def flc_setup(
370
- rfftn: Callable,
371
- irfftn: Callable,
372
- template: NDArray,
373
- template_mask: NDArray,
374
- target: NDArray,
375
- fast_shape: Tuple[int],
376
- fast_ft_shape: Tuple[int],
377
- real_dtype: type,
378
- complex_dtype: type,
379
- shared_memory_handler: Callable,
380
- callback_class: Callable,
381
- callback_class_args: Dict,
382
- **kwargs,
383
- ) -> Dict:
384
- """
385
- Setup to compute a normalized cross-correlation score of a target f a template g
386
- and a mask m:
387
-
388
- .. math::
389
-
390
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
391
- {N_m * \\sqrt{
392
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
393
- }
394
-
395
- Where:
396
-
397
- .. math::
398
-
399
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
400
-
401
- and Nm is the number of voxels within the template mask m.
402
-
403
- References
404
- ----------
405
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
406
- Microsc. Microanal. 26, 2516 (2020)
407
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
408
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
409
-
410
- See Also
411
- --------
412
- :py:meth:`flc_scoring`
413
- """
414
- target_pad = backend.topleft_pad(target, fast_shape)
415
-
416
- # Target and squared target window sums
417
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
418
- ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
419
- rfftn(target_pad, ft_target)
420
- backend.square(target_pad, out=target_pad)
421
- rfftn(target_pad, ft_target2)
422
-
423
- # Convert arrays used in subsequent fitting to SharedMemory objects
424
- ft_target = backend.arr_to_sharedarr(
425
- arr=ft_target, shared_memory_handler=shared_memory_handler
426
- )
427
- ft_target2 = backend.arr_to_sharedarr(
428
- arr=ft_target2, shared_memory_handler=shared_memory_handler
429
- )
430
-
431
- normalize_under_mask(
432
- template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
433
- )
434
-
435
- template_buffer = backend.arr_to_sharedarr(
436
- arr=template, shared_memory_handler=shared_memory_handler
437
- )
438
- template_mask_buffer = backend.arr_to_sharedarr(
439
- arr=template_mask, shared_memory_handler=shared_memory_handler
440
- )
441
-
442
- template_tuple = (template_buffer, template.shape, real_dtype)
443
- template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)
444
-
445
- target_ft_tuple = (ft_target, fast_ft_shape, complex_dtype)
446
- target_ft2_tuple = (ft_target2, fast_ft_shape, complex_dtype)
447
-
448
- ret = {
449
- "template": template_tuple,
450
- "template_mask": template_mask_tuple,
451
- "ft_target": target_ft_tuple,
452
- "ft_target2": target_ft2_tuple,
453
- "targetshape": target.shape,
454
- "templateshape": template.shape,
455
- "fast_shape": fast_shape,
456
- "fast_ft_shape": fast_ft_shape,
457
- "real_dtype": real_dtype,
458
- "complex_dtype": complex_dtype,
459
- "callback_class": callback_class,
460
- "callback_class_args": callback_class_args,
461
- }
462
-
463
- return ret
464
-
465
-
466
- def flcSphericalMask_setup(
467
- rfftn: Callable,
468
- irfftn: Callable,
469
- template: NDArray,
470
- template_mask: NDArray,
471
- target: NDArray,
472
- fast_shape: Tuple[int],
473
- fast_ft_shape: Tuple[int],
474
- real_dtype: type,
475
- complex_dtype: type,
476
- shared_memory_handler: Callable,
477
- callback_class: Callable,
478
- callback_class_args: Dict,
479
- **kwargs,
480
- ) -> Dict:
481
- """
482
- Like :py:meth:`flc_setup` but for rotation invariant masks. In such cases
483
- the score can be computed quicker by not computing the fourier transforms
484
- of the mask for each rotation.
485
-
486
- .. math::
487
-
488
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
489
- {N_m * \\sqrt{
490
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
491
- }
492
-
493
- Where:
494
-
495
- .. math::
496
-
497
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
498
-
499
- and Nm is the number of voxels within the template mask m.
500
-
501
- References
502
- ----------
503
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
504
- Microsc. Microanal. 26, 2516 (2020)
505
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
506
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
507
-
508
- See Also
509
- --------
510
- :py:meth:`corr_scoring`
511
- """
512
- target_pad = backend.topleft_pad(target, fast_shape)
513
-
514
- # Target and squared target window sums
515
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
516
- ft_template_mask = backend.preallocate_array(fast_ft_shape, complex_dtype)
517
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
518
-
519
- temp = backend.preallocate_array(fast_shape, real_dtype)
520
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
521
- numerator2 = backend.preallocate_array(1, real_dtype)
522
-
523
- eps = backend.eps(real_dtype)
524
- n_observations = backend.sum(template_mask)
525
-
526
- template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
527
- rfftn(template_mask_pad, ft_template_mask)
528
-
529
- # Denominator E(X^2) - E(X)^2
530
- rfftn(backend.square(target_pad), ft_target)
531
- backend.multiply(ft_target, ft_template_mask, out=ft_temp)
532
- irfftn(ft_temp, temp2)
533
- backend.divide(temp2, n_observations, out=temp2)
534
-
535
- rfftn(target_pad, ft_target)
536
- backend.multiply(ft_target, ft_template_mask, out=ft_temp)
537
- irfftn(ft_temp, temp)
538
- backend.divide(temp, n_observations, out=temp)
539
- backend.square(temp, out=temp)
540
-
541
- backend.subtract(temp2, temp, out=temp)
542
-
543
- backend.maximum(temp, 0.0, out=temp)
544
- backend.sqrt(temp, out=temp)
545
- backend.multiply(temp, n_observations, out=temp)
546
-
547
- tol = 1e3 * eps * backend.max(backend.abs(temp))
548
- nonzero_indices = temp > tol
549
-
550
- backend.fill(temp2, 0)
551
- temp2[nonzero_indices] = 1 / temp[nonzero_indices]
552
-
553
- normalize_under_mask(
554
- template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
555
- )
556
-
557
- template_buffer = backend.arr_to_sharedarr(
558
- arr=template, shared_memory_handler=shared_memory_handler
559
- )
560
- template_mask_buffer = backend.arr_to_sharedarr(
561
- arr=template_mask, shared_memory_handler=shared_memory_handler
562
- )
563
- target_ft_buffer = backend.arr_to_sharedarr(
564
- arr=ft_target, shared_memory_handler=shared_memory_handler
565
- )
566
- inv_denominator_buffer = backend.arr_to_sharedarr(
567
- arr=temp2, shared_memory_handler=shared_memory_handler
568
- )
569
- numerator2_buffer = backend.arr_to_sharedarr(
570
- arr=numerator2, shared_memory_handler=shared_memory_handler
571
- )
572
-
573
- template_tuple = (template_buffer, template.shape, real_dtype)
574
- template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
575
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
576
-
577
- inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
578
- numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
579
-
580
- ret = {
581
- "template": template_tuple,
582
- "template_mask": template_mask_tuple,
583
- "ft_target": target_ft_tuple,
584
- "inv_denominator": inv_denominator_tuple,
585
- "numerator2": numerator2_tuple,
586
- "targetshape": target.shape,
587
- "templateshape": template.shape,
588
- "fast_shape": fast_shape,
589
- "fast_ft_shape": fast_ft_shape,
590
- "real_dtype": real_dtype,
591
- "complex_dtype": complex_dtype,
592
- "callback_class": callback_class,
593
- "callback_class_args": callback_class_args,
594
- }
595
-
596
- return ret
597
-
598
-
599
- def mcc_setup(
600
- rfftn: Callable,
601
- irfftn: Callable,
602
- template: NDArray,
603
- template_mask: NDArray,
604
- target: NDArray,
605
- target_mask: NDArray,
606
- fast_shape: Tuple[int],
607
- fast_ft_shape: Tuple[int],
608
- real_dtype: type,
609
- complex_dtype: type,
610
- shared_memory_handler: Callable,
611
- callback_class: Callable,
612
- callback_class_args: Dict,
613
- **kwargs,
614
- ) -> Dict:
615
- """
616
- Setup to compute a normalized cross-correlation score with masks
617
- for both template and target.
618
-
619
- .. math::
620
-
621
- \\frac{
622
- CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
623
- }{
624
- \\sqrt{
625
- (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
626
- (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
627
- }
628
- }
629
-
630
- Where:
631
-
632
- .. math::
633
-
634
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
635
-
636
-
637
- References
638
- ----------
639
- .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
640
-
641
- See Also
642
- --------
643
- :py:meth:`mcc_scoring`
644
- :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
645
- """
646
- target = backend.multiply(target, target_mask > 0, out=target)
647
-
648
- target_pad = backend.topleft_pad(target, fast_shape)
649
- target_mask_pad = backend.topleft_pad(target_mask, fast_shape)
650
-
651
- target_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
652
- rfftn(target_pad, target_ft)
653
- target_ft_buffer = backend.arr_to_sharedarr(
654
- arr=target_ft, shared_memory_handler=shared_memory_handler
655
- )
656
-
657
- target_ft2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
658
- rfftn(backend.square(target_pad), target_ft2)
659
- target_ft2_buffer = backend.arr_to_sharedarr(
660
- arr=target_ft2, shared_memory_handler=shared_memory_handler
661
- )
31
+ last_type : type
32
+ The type of the last exception.
33
+ last_value :
34
+ The value of the last exception.
35
+ last_traceback : traceback
36
+ Traceback call stack at the point where the Exception occured.
662
37
 
663
- target_mask_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
664
- rfftn(target_mask_pad, target_mask_ft)
665
- target_mask_ft_buffer = backend.arr_to_sharedarr(
666
- arr=target_mask_ft, shared_memory_handler=shared_memory_handler
667
- )
668
-
669
- template_buffer = backend.arr_to_sharedarr(
670
- arr=template, shared_memory_handler=shared_memory_handler
671
- )
672
- template_mask_buffer = backend.arr_to_sharedarr(
673
- arr=template_mask, shared_memory_handler=shared_memory_handler
674
- )
675
-
676
- template_tuple = (template_buffer, template.shape, real_dtype)
677
- template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
678
-
679
- target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
680
- target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
681
- target_mask_ft_tuple = (target_mask_ft_buffer, fast_ft_shape, complex_dtype)
682
-
683
- ret = {
684
- "template": template_tuple,
685
- "template_mask": template_mask_tuple,
686
- "ft_target": target_ft_tuple,
687
- "ft_target2": target_ft2_tuple,
688
- "ft_target_mask": target_mask_ft_tuple,
689
- "targetshape": target.shape,
690
- "templateshape": template.shape,
691
- "fast_shape": fast_shape,
692
- "fast_ft_shape": fast_ft_shape,
693
- "real_dtype": real_dtype,
694
- "complex_dtype": complex_dtype,
695
- "callback_class": callback_class,
696
- "callback_class_args": callback_class_args,
697
- }
698
-
699
- return ret
700
-
701
-
702
- def corr_scoring(
703
- template: Tuple[type, Tuple[int], type],
704
- ft_target: Tuple[type, Tuple[int], type],
705
- inv_denominator: Tuple[type, Tuple[int], type],
706
- numerator2: Tuple[type, Tuple[int], type],
707
- template_filter: Tuple[type, Tuple[int], type],
708
- targetshape: Tuple[int],
709
- templateshape: Tuple[int],
710
- fast_shape: Tuple[int],
711
- fast_ft_shape: Tuple[int],
712
- rotations: NDArray,
713
- real_dtype: type,
714
- complex_dtype: type,
715
- callback_class: CallbackClass,
716
- callback_class_args: Dict,
717
- interpolation_order: int,
718
- convolution_mode: str = "full",
719
- **kwargs,
720
- ) -> CallbackClass:
721
- """
722
- Calculates a normalized cross-correlation between a target f and a template g.
723
-
724
- .. math::
725
-
726
- (CC(f,g) - numerator2) \\cdot inv\\_denominator
727
-
728
- Parameters
729
- ----------
730
- template : Tuple[type, Tuple[int], type]
731
- Tuple containing a pointer to the template data, its shape, and its datatype.
732
- ft_target : Tuple[type, Tuple[int], type]
733
- Tuple containing a pointer to the fourier tranform of the target,
734
- its shape, and its datatype.
735
- inv_denominator : Tuple[type, Tuple[int], type]
736
- Tuple containing a pointer to the inverse denominator data, its shape, and its
737
- datatype.
738
- numerator2 : Tuple[type, Tuple[int], type]
739
- Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
740
- fast_shape : Tuple[int]
741
- The shape for fast Fourier transform.
742
- fast_ft_shape : Tuple[int]
743
- The shape for fast Fourier transform of the target.
744
- rotations : NDArray
745
- Array containing the rotation matrices to be applied on the template.
746
- real_dtype : type
747
- Data type for the real part of the array.
748
- complex_dtype : type
749
- Data type for the complex part of the array.
750
- callback_class : CallbackClass
751
- A callable class or function for processing the results after each
752
- rotation.
753
- callback_class_args : Dict
754
- Dictionary of arguments to be passed to the callback class if it's
755
- instantiable.
756
- interpolation_order : int
757
- The order of interpolation to be used while rotating the template.
758
- **kwargs :
759
- Additional arguments to be passed to the function.
760
-
761
- Returns
762
- -------
763
- CallbackClass
764
- If callback_class was provided an instance of callback_class otherwise None.
765
-
766
- See Also
767
- --------
768
- :py:meth:`cc_setup`
769
- :py:meth:`corr_setup`
770
- :py:meth:`cam_setup`
771
- :py:meth:`flcSphericalMask_setup`
772
- """
773
- callback = callback_class
774
- if callback_class is not None and isinstance(callback_class, type):
775
- callback = callback_class(**callback_class_args)
776
-
777
- template_buffer, template_shape, template_dtype = template
778
- template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
779
- ft_target = backend.sharedarr_to_arr(*ft_target)
780
- inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
781
- numerator2 = backend.sharedarr_to_arr(*numerator2)
782
- template_filter = backend.sharedarr_to_arr(*template_filter)
783
-
784
- norm_template, template_mask, mask_sum = False, 1, 1
785
- if "template_mask" in kwargs:
786
- template_mask = backend.sharedarr_to_arr(
787
- kwargs["template_mask"][0], template_shape, template_dtype
788
- )
789
- norm_template, mask_sum = True, backend.sum(template_mask)
790
-
791
- arr = backend.preallocate_array(fast_shape, real_dtype)
792
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
793
-
794
- rfftn, irfftn = backend.build_fft(
795
- fast_shape=fast_shape,
796
- fast_ft_shape=fast_ft_shape,
797
- real_dtype=real_dtype,
798
- complex_dtype=complex_dtype,
799
- fftargs=kwargs.get("fftargs", {}),
800
- temp_real=arr,
801
- temp_fft=ft_temp,
802
- )
803
-
804
- norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
805
- norm_denominator = (backend.sum(inv_denominator) != 1) & (
806
- backend.size(inv_denominator) != 1
807
- )
808
-
809
- norm_template = conditional_execute(normalize_under_mask, norm_template)
810
- callback_func = conditional_execute(callback, callback_class is not None)
811
- norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
812
- norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
813
- template_filter_func = conditional_execute(
814
- apply_filter, backend.size(template_filter) != 1
815
- )
816
-
817
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
818
- for index in range(rotations.shape[0]):
819
- rotation = rotations[index]
820
- backend.fill(arr, 0)
821
- backend.rotate_array(
822
- arr=template,
823
- rotation_matrix=rotation,
824
- out=arr,
825
- use_geometric_center=False,
826
- order=interpolation_order,
827
- )
828
-
829
- norm_template(arr[unpadded_slice], template_mask, mask_sum)
830
-
831
- rfftn(arr, ft_temp)
832
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
833
- backend.multiply(ft_target, ft_temp, out=ft_temp)
834
- irfftn(ft_temp, arr)
835
-
836
- norm_func_numerator(arr, numerator2, out=arr)
837
- norm_func_denominator(arr, inv_denominator, out=arr)
838
-
839
- callback_func(
840
- arr,
841
- rotation_matrix=rotation,
842
- rotation_index=index,
843
- **callback_class_args,
844
- )
845
-
846
- return callback
847
-
848
-
849
- def flc_scoring(
850
- template: Tuple[type, Tuple[int], type],
851
- template_mask: Tuple[type, Tuple[int], type],
852
- ft_target: Tuple[type, Tuple[int], type],
853
- ft_target2: Tuple[type, Tuple[int], type],
854
- template_filter: Tuple[type, Tuple[int], type],
855
- targetshape: Tuple[int],
856
- templateshape: Tuple[int],
857
- fast_shape: Tuple[int],
858
- fast_ft_shape: Tuple[int],
859
- rotations: NDArray,
860
- real_dtype: type,
861
- complex_dtype: type,
862
- callback_class: CallbackClass,
863
- callback_class_args: Dict,
864
- interpolation_order: int,
865
- **kwargs,
866
- ) -> CallbackClass:
867
- """
868
- Computes a normalized cross-correlation score of a target f a template g
869
- and a mask m:
870
-
871
- .. math::
872
-
873
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
874
- {N_m * \\sqrt{
875
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
876
- }
877
-
878
- Where:
879
-
880
- .. math::
881
-
882
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
883
-
884
- and Nm is the number of voxels within the template mask m.
885
-
886
- References
887
- ----------
888
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
889
- Microsc. Microanal. 26, 2516 (2020)
890
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
891
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
892
- """
893
- callback = callback_class
894
- if callback_class is not None and isinstance(callback_class, type):
895
- callback = callback_class(**callback_class_args)
896
-
897
- template = backend.sharedarr_to_arr(*template)
898
- template_mask = backend.sharedarr_to_arr(*template_mask)
899
- ft_target = backend.sharedarr_to_arr(*ft_target)
900
- ft_target2 = backend.sharedarr_to_arr(*ft_target2)
901
- template_filter = backend.sharedarr_to_arr(*template_filter)
902
-
903
- arr = backend.preallocate_array(fast_shape, real_dtype)
904
- temp = backend.preallocate_array(fast_shape, real_dtype)
905
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
906
-
907
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
908
- ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
909
-
910
- rfftn, irfftn = backend.build_fft(
911
- fast_shape=fast_shape,
912
- fast_ft_shape=fast_ft_shape,
913
- real_dtype=real_dtype,
914
- complex_dtype=complex_dtype,
915
- fftargs=kwargs.get("fftargs", {}),
916
- temp_real=arr,
917
- temp_fft=ft_temp,
918
- )
919
- eps = backend.eps(real_dtype)
920
- template_filter_func = conditional_execute(
921
- apply_filter, backend.size(template_filter) != 1
922
- )
923
- callback_func = conditional_execute(callback, callback_class is not None)
924
-
925
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
926
- for index in range(rotations.shape[0]):
927
- rotation = rotations[index]
928
- backend.fill(arr, 0)
929
- backend.fill(temp, 0)
930
- backend.rotate_array(
931
- arr=template,
932
- arr_mask=template_mask,
933
- rotation_matrix=rotation,
934
- out=arr,
935
- out_mask=temp,
936
- use_geometric_center=False,
937
- order=interpolation_order,
938
- )
939
- # Given the amount of FFTs, might aswell normalize properly
940
- n_observations = backend.sum(temp)
941
-
942
- normalize_under_mask(
943
- template=arr[unpadded_slice],
944
- mask=temp[unpadded_slice],
945
- mask_intensity=n_observations,
946
- )
947
-
948
- rfftn(temp, ft_temp)
949
-
950
- backend.multiply(ft_target, ft_temp, out=ft_denom)
951
- irfftn(ft_denom, temp)
952
- backend.divide(temp, n_observations, out=temp)
953
- backend.square(temp, out=temp)
954
-
955
- backend.multiply(ft_target2, ft_temp, out=ft_denom)
956
- irfftn(ft_denom, temp2)
957
- backend.divide(temp2, n_observations, out=temp2)
958
-
959
- backend.subtract(temp2, temp, out=temp)
960
- backend.maximum(temp, 0.0, out=temp)
961
- backend.sqrt(temp, out=temp)
962
- backend.multiply(temp, n_observations, out=temp)
963
-
964
- rfftn(arr, ft_temp)
965
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
966
- backend.multiply(ft_target, ft_temp, out=ft_temp)
967
- irfftn(ft_temp, arr)
968
-
969
- tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
970
- nonzero_indices = temp > tol
971
- backend.fill(temp2, 0)
972
- temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
973
-
974
- callback_func(
975
- temp2,
976
- rotation_matrix=rotation,
977
- rotation_index=index,
978
- **callback_class_args,
979
- )
980
-
981
- return callback
982
-
983
-
984
- def flc_scoring2(
985
- template: Tuple[type, Tuple[int], type],
986
- template_mask: Tuple[type, Tuple[int], type],
987
- ft_target: Tuple[type, Tuple[int], type],
988
- ft_target2: Tuple[type, Tuple[int], type],
989
- template_filter: Tuple[type, Tuple[int], type],
990
- targetshape: Tuple[int],
991
- templateshape: Tuple[int],
992
- fast_shape: Tuple[int],
993
- fast_ft_shape: Tuple[int],
994
- rotations: NDArray,
995
- real_dtype: type,
996
- complex_dtype: type,
997
- callback_class: CallbackClass,
998
- callback_class_args: Dict,
999
- interpolation_order: int,
1000
- **kwargs,
1001
- ) -> CallbackClass:
1002
- """
1003
- Computes a normalized cross-correlation score of a target f a template g
1004
- and a mask m:
1005
-
1006
- .. math::
1007
-
1008
- \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
1009
- {N_m * \\sqrt{
1010
- \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
1011
- }
1012
-
1013
- Where:
1014
-
1015
- .. math::
1016
-
1017
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1018
-
1019
- and Nm is the number of voxels within the template mask m.
1020
-
1021
- References
1022
- ----------
1023
- .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
1024
- Microsc. Microanal. 26, 2516 (2020)
1025
- .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1026
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
1027
- """
1028
- callback = callback_class
1029
- if callback_class is not None and isinstance(callback_class, type):
1030
- callback = callback_class(**callback_class_args)
1031
-
1032
- # Retrieve objects from shared memory
1033
- template = backend.sharedarr_to_arr(*template)
1034
- template_mask = backend.sharedarr_to_arr(*template_mask)
1035
- ft_target = backend.sharedarr_to_arr(*ft_target)
1036
- ft_target2 = backend.sharedarr_to_arr(*ft_target2)
1037
- template_filter = backend.sharedarr_to_arr(*template_filter)
1038
-
1039
- arr = backend.preallocate_array(fast_shape, real_dtype)
1040
- temp = backend.preallocate_array(fast_shape, real_dtype)
1041
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
1042
-
1043
- ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
1044
- ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
1045
-
1046
- eps = backend.eps(real_dtype)
1047
- template_filter_func = conditional_execute(
1048
- apply_filter, backend.size(template_filter) != 1
1049
- )
1050
- callback_func = conditional_execute(callback, callback_class is not None)
1051
-
1052
- squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
1053
- squeeze = tuple(
1054
- slice(0, stop) if i not in squeeze_axis else 0
1055
- for i, stop in enumerate(template.shape)
1056
- )
1057
- squeeze_fast = tuple(
1058
- slice(0, stop) if i not in squeeze_axis else 0
1059
- for i, stop in enumerate(fast_shape)
1060
- )
1061
- squeeze_fast_ft = tuple(
1062
- slice(0, stop) if i not in squeeze_axis else 0
1063
- for i, stop in enumerate(fast_ft_shape)
1064
- )
1065
-
1066
- rfftn, irfftn = backend.build_fft(
1067
- fast_shape=temp[squeeze_fast].shape,
1068
- fast_ft_shape=fast_ft_shape,
1069
- real_dtype=real_dtype,
1070
- complex_dtype=complex_dtype,
1071
- fftargs=kwargs.get("fftargs", {}),
1072
- inverse_fast_shape=fast_shape,
1073
- temp_real=arr[squeeze_fast],
1074
- temp_fft=ft_temp,
1075
- )
1076
- for index in range(rotations.shape[0]):
1077
- rotation = rotations[index]
1078
- backend.fill(arr, 0)
1079
- backend.fill(temp, 0)
1080
- backend.rotate_array(
1081
- arr=template[squeeze],
1082
- arr_mask=template_mask[squeeze],
1083
- rotation_matrix=rotation,
1084
- out=arr[squeeze],
1085
- out_mask=temp[squeeze],
1086
- use_geometric_center=False,
1087
- order=interpolation_order,
1088
- )
1089
- # Given the amount of FFTs, might aswell normalize properly
1090
- n_observations = backend.sum(temp)
1091
-
1092
- normalize_under_mask(
1093
- template=arr[squeeze],
1094
- mask=temp[squeeze],
1095
- mask_intensity=n_observations,
1096
- )
1097
- rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
1098
-
1099
- backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1100
- irfftn(ft_denom, temp)
1101
- backend.divide(temp, n_observations, out=temp)
1102
- backend.square(temp, out=temp)
1103
-
1104
- backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
1105
- irfftn(ft_denom, temp2)
1106
- backend.divide(temp2, n_observations, out=temp2)
1107
-
1108
- backend.subtract(temp2, temp, out=temp)
1109
- backend.maximum(temp, 0.0, out=temp)
1110
- backend.sqrt(temp, out=temp)
1111
- backend.multiply(temp, n_observations, out=temp)
1112
-
1113
- rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
1114
- template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1115
- backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1116
- irfftn(ft_denom, arr)
1117
-
1118
- tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
1119
- nonzero_indices = temp > tol
1120
- backend.fill(temp2, 0)
1121
- temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1122
-
1123
- callback_func(
1124
- temp2,
1125
- rotation_matrix=rotation,
1126
- rotation_index=index,
1127
- **callback_class_args,
1128
- )
1129
- return callback
1130
-
1131
-
1132
- def mcc_scoring(
1133
- template: Tuple[type, Tuple[int], type],
1134
- template_mask: Tuple[type, Tuple[int], type],
1135
- ft_target: Tuple[type, Tuple[int], type],
1136
- ft_target2: Tuple[type, Tuple[int], type],
1137
- ft_target_mask: Tuple[type, Tuple[int], type],
1138
- template_filter: Tuple[type, Tuple[int], type],
1139
- targetshape: Tuple[int],
1140
- templateshape: Tuple[int],
1141
- fast_shape: Tuple[int],
1142
- fast_ft_shape: Tuple[int],
1143
- rotations: NDArray,
1144
- real_dtype: type,
1145
- complex_dtype: type,
1146
- callback_class: CallbackClass,
1147
- callback_class_args: type,
1148
- interpolation_order: int,
1149
- overlap_ratio: float = 0.3,
1150
- **kwargs,
1151
- ) -> CallbackClass:
1152
- """
1153
- Computes a cross-correlation score with masks for both template and target.
1154
-
1155
- .. math::
1156
-
1157
- \\frac{
1158
- CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
1159
- }{
1160
- \\sqrt{
1161
- (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
1162
- (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
1163
- }
1164
- }
1165
-
1166
- Where:
1167
-
1168
- .. math::
1169
-
1170
- CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1171
-
1172
-
1173
- References
1174
- ----------
1175
- .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
1176
- .. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
1177
-
1178
- See Also
1179
- --------
1180
- :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
38
+ Raises
39
+ ------
40
+ Exception
41
+ Re-raises the last exception.
1181
42
  """
1182
- callback = callback_class
1183
- if callback_class is not None and isinstance(callback_class, type):
1184
- callback = callback_class(**callback_class_args)
1185
-
1186
- # Retrieve objects from shared memory
1187
- template_buffer, template_shape, template_dtype = template
1188
- template = backend.sharedarr_to_arr(*template)
1189
- target_ft = backend.sharedarr_to_arr(*ft_target)
1190
- target_ft2 = backend.sharedarr_to_arr(*ft_target2)
1191
- template_mask = backend.sharedarr_to_arr(*template_mask)
1192
- target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
1193
- template_filter = backend.sharedarr_to_arr(*template_filter)
1194
-
1195
- axes = tuple(range(template.ndim))
1196
- eps = backend.eps(real_dtype)
1197
-
1198
- # Allocate score and process specific arrays
1199
- template_rot = backend.preallocate_array(fast_shape, real_dtype)
1200
- mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
1201
- numerator = backend.preallocate_array(fast_shape, real_dtype)
1202
-
1203
- temp = backend.preallocate_array(fast_shape, real_dtype)
1204
- temp2 = backend.preallocate_array(fast_shape, real_dtype)
1205
- temp3 = backend.preallocate_array(fast_shape, real_dtype)
1206
- temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
1207
-
1208
- rfftn, irfftn = backend.build_fft(
1209
- fast_shape=fast_shape,
1210
- fast_ft_shape=fast_ft_shape,
1211
- real_dtype=real_dtype,
1212
- complex_dtype=complex_dtype,
1213
- fftargs=kwargs.get("fftargs", {}),
1214
- temp_real=numerator,
1215
- temp_fft=temp_ft,
1216
- )
1217
-
1218
- template_filter_func = conditional_execute(
1219
- apply_filter, backend.size(template_filter) != 1
1220
- )
1221
- callback_func = conditional_execute(callback, callback_class is not None)
1222
-
1223
- # Calculate scores across all rotations
1224
- for index in range(rotations.shape[0]):
1225
- rotation = rotations[index]
1226
- backend.fill(template_rot, 0)
1227
- backend.fill(temp, 0)
1228
-
1229
- backend.rotate_array(
1230
- arr=template,
1231
- arr_mask=template_mask,
1232
- rotation_matrix=rotation,
1233
- out=template_rot,
1234
- out_mask=temp,
1235
- use_geometric_center=False,
1236
- order=interpolation_order,
1237
- )
1238
-
1239
- backend.multiply(template_rot, temp > 0, out=template_rot)
1240
-
1241
- # template_rot_ft
1242
- rfftn(template_rot, temp_ft)
1243
- template_filter_func(ft_template=temp_ft, template_filter=template_filter)
1244
- irfftn(target_mask_ft * temp_ft, temp2)
1245
- irfftn(target_ft * temp_ft, numerator)
1246
-
1247
- # temp template_mask_rot | temp_ft template_mask_rot_ft
1248
- # Calculate overlap of masks at every point in the convolution.
1249
- # Locations with high overlap should not be taken into account.
1250
- rfftn(temp, temp_ft)
1251
- irfftn(temp_ft * target_mask_ft, mask_overlap)
1252
- mask_overlap[:] = np.round(mask_overlap)
1253
- mask_overlap[:] = np.maximum(mask_overlap, eps)
1254
- irfftn(temp_ft * target_ft, temp)
1255
-
1256
- backend.subtract(
1257
- numerator,
1258
- backend.divide(backend.multiply(temp, temp2), mask_overlap),
1259
- out=numerator,
1260
- )
1261
-
1262
- # temp_3 = fixed_denom
1263
- backend.multiply(temp_ft, target_ft2, out=temp_ft)
1264
- irfftn(temp_ft, temp3)
1265
- backend.subtract(
1266
- temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
1267
- )
1268
- backend.maximum(temp3, 0.0, out=temp3)
1269
-
1270
- # temp = moving_denom
1271
- rfftn(backend.square(template_rot), temp_ft)
1272
- backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
1273
- irfftn(temp_ft, temp)
1274
-
1275
- backend.subtract(
1276
- temp, backend.divide(backend.square(temp2), mask_overlap), out=temp
1277
- )
1278
- backend.maximum(temp, 0.0, out=temp)
1279
-
1280
- # temp_2 = denom
1281
- backend.multiply(temp3, temp, out=temp)
1282
- backend.sqrt(temp, temp2)
43
+ if last_type is None:
44
+ return None
45
+ traceback.print_tb(last_traceback)
46
+ raise Exception(last_value)
1283
47
 
1284
- # Pixels where `denom` is very small will introduce large
1285
- # numbers after division. To get around this problem,
1286
- # we zero-out problematic pixels.
1287
- tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
1288
- nonzero_indices = temp2 > tol
1289
48
 
1290
- backend.fill(temp, 0)
1291
- temp[nonzero_indices] = numerator[nonzero_indices] / temp2[nonzero_indices]
1292
- backend.clip(temp, a_min=-1, a_max=1, out=temp)
1293
-
1294
- # Apply overlap ratio threshold
1295
- number_px_threshold = overlap_ratio * backend.max(
1296
- mask_overlap, axis=axes, keepdims=True
1297
- )
1298
- temp[mask_overlap < number_px_threshold] = 0.0
49
+ def _wrap_backend(func):
50
+ @wraps(func)
51
+ def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
52
+ from tme.backends import backend as be
1299
53
 
1300
- callback_func(
1301
- temp,
1302
- rotation_matrix=rotation,
1303
- rotation_index=index,
1304
- **callback_class_args,
1305
- )
54
+ be.change_backend(backend_name, **backend_args)
55
+ return func(*args, **kwargs)
1306
56
 
1307
- return callback
57
+ return wrapper
1308
58
 
1309
59
 
1310
- def _setup_template_filter(
60
+ def _setup_template_filter_apply_target_filter(
1311
61
  matching_data: MatchingData,
1312
62
  rfftn: Callable,
1313
63
  irfftn: Callable,
1314
64
  fast_shape: Tuple[int],
1315
65
  fast_ft_shape: Tuple[int],
66
+ pad_template_filter: bool = True,
1316
67
  ):
1317
68
  filter_template = isinstance(matching_data.template_filter, Compose)
1318
69
  filter_target = isinstance(matching_data.target_filter, Compose)
1319
70
 
1320
- template_filter = backend.full(
1321
- shape=(1,), fill_value=1, dtype=backend._default_dtype
1322
- )
71
+ template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
1323
72
 
1324
73
  if not filter_template and not filter_target:
1325
74
  return template_filter
1326
75
 
1327
- target_temp = backend.astype(
1328
- backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
1329
- )
1330
- target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1331
- rfftn(target_temp, target_temp_ft)
76
+ target_temp = be.topleft_pad(matching_data.target, fast_shape)
77
+ target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
78
+
79
+ inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
80
+ filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
81
+ filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
82
+
83
+ fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
84
+ fast_shape = tuple(int(x) for x in fast_shape if x != 0)
85
+
86
+ target_temp_ft = rfftn(target_temp, target_temp_ft)
87
+ if filter_template:
88
+ # TODO: Pad to fast shapes and adapt _setup_template_filtering accordingly
89
+ template_fast_shape, template_filter_shape = fast_shape, filter_shape
90
+ if not pad_template_filter:
91
+ template_fast_shape = tuple(int(x) for x in matching_data._template.shape)
92
+ template_filter_shape = [
93
+ int(x) for x in matching_data._output_template_shape
94
+ ]
95
+ template_filter_shape[-1] = template_filter_shape[-1] // 2 + 1
1332
96
 
1333
- if isinstance(matching_data.template_filter, Compose):
1334
97
  template_filter = matching_data.template_filter(
1335
- shape=fast_shape,
98
+ shape=template_fast_shape,
1336
99
  return_real_fourier=True,
1337
100
  shape_is_real_fourier=False,
1338
101
  data_rfft=target_temp_ft,
1339
- )
1340
- template_filter = template_filter["data"]
1341
- template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
102
+ batch_dimension=matching_data._target_dims,
103
+ )["data"]
104
+ template_filter = be.reshape(template_filter, template_filter_shape)
1342
105
 
1343
- if isinstance(matching_data.target_filter, Compose):
106
+ if filter_target:
1344
107
  target_filter = matching_data.target_filter(
1345
108
  shape=fast_shape,
1346
109
  return_real_fourier=True,
1347
110
  shape_is_real_fourier=False,
1348
111
  data_rfft=target_temp_ft,
1349
112
  weight_type=None,
1350
- )
1351
- target_filter = target_filter["data"]
1352
- backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
113
+ batch_dimension=matching_data._target_dims,
114
+ )["data"]
115
+ target_filter = be.reshape(target_filter, filter_shape)
116
+ target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1353
117
 
1354
- irfftn(target_temp_ft, target_temp)
1355
- matching_data._target = backend.topleft_pad(
1356
- target_temp, matching_data.target.shape
1357
- )
118
+ target_temp = irfftn(target_temp_ft, target_temp)
119
+ matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)
1358
120
 
1359
121
  return template_filter
1360
122
 
@@ -1368,13 +130,14 @@ def device_memory_handler(func: Callable):
1368
130
  last_type, last_value, last_traceback = sys.exc_info()
1369
131
  try:
1370
132
  with SharedMemoryManager() as smh:
1371
- with backend.set_device(kwargs.get("gpu_index", 0)):
133
+ gpu_index = kwargs.pop("gpu_index") if "gpu_index" in kwargs else 0
134
+ with be.set_device(gpu_index):
1372
135
  return_value = func(shared_memory_handler=smh, *args, **kwargs)
1373
136
  except Exception as e:
1374
137
  print(e)
1375
138
  last_type, last_value, last_traceback = sys.exc_info()
1376
139
  finally:
1377
- handle_traceback(last_type, last_value, last_traceback)
140
+ _handle_traceback(last_type, last_value, last_traceback)
1378
141
  return return_value
1379
142
 
1380
143
  return inner_function
@@ -1388,19 +151,20 @@ def scan(
1388
151
  n_jobs: int = 4,
1389
152
  callback_class: CallbackClass = None,
1390
153
  callback_class_args: Dict = {},
1391
- fftargs: Dict = {},
1392
154
  pad_fourier: bool = True,
155
+ pad_template_filter: bool = True,
1393
156
  interpolation_order: int = 3,
1394
157
  jobs_per_callback_class: int = 8,
1395
- **kwargs,
1396
- ) -> Tuple:
158
+ shared_memory_handler=None,
159
+ ) -> Optional[Tuple]:
1397
160
  """
1398
- Perform template matching between target and template and sample
1399
- different rotations of template.
161
+ Run template matching.
162
+
163
+ .. warning:: ``matching_data`` might be altered or destroyed during computation.
1400
164
 
1401
165
  Parameters
1402
166
  ----------
1403
- matching_data : MatchingData
167
+ matching_data : :py:class:`tme.matching_data.MatchingData`
1404
168
  Template matching data.
1405
169
  matching_setup : Callable
1406
170
  Function pointer to setup function.
@@ -1412,186 +176,131 @@ def scan(
1412
176
  Analyzer class pointer to operate on computed scores.
1413
177
  callback_class_args : dict, optional
1414
178
  Arguments passed to the callback_class. Default is an empty dictionary.
1415
- fftargs : dict, optional
1416
- Arguments for the FFT operations. Default is an empty dictionary.
1417
179
  pad_fourier: bool, optional
1418
180
  Whether to pad target and template to the full convolution shape.
181
+ pad_template_filter: bool, optional
182
+ Whether to pad potential template filters to the full convolution shape.
1419
183
  interpolation_order : int, optional
1420
184
  Order of spline interpolation for rotations.
1421
185
  jobs_per_callback_class : int, optional
1422
186
  How many jobs should be processed by a single callback_class instance,
1423
- if ones is provided.
1424
- **kwargs : various
1425
- Additional arguments.
187
+ if one is provided.
188
+ shared_memory_handler : type, optional
189
+ Manager for shared memory objects, None by default.
1426
190
 
1427
191
  Returns
1428
192
  -------
1429
- Tuple
193
+ Optional[Tuple]
1430
194
  The merged results from callback_class if provided otherwise None.
195
+
196
+ Examples
197
+ --------
198
+ Schematically, using :py:meth:`scan` is similar to :py:meth:`scan_subsets`,
199
+ with the distinction that the objects contained in ``matching_data`` are not
200
+ split and the search is only parallelized over angles.
201
+ Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
202
+ can be invoked like so
203
+
204
+ >>> from tme.matching_exhaustive import scan
205
+ >>> results = scan(
206
+ >>> matching_data = matching_data,
207
+ >>> matching_score = matching_score,
208
+ >>> matching_setup = matching_setup,
209
+ >>> callback_class = callback_class,
210
+ >>> callback_class_args = callback_class_args,
211
+ >>> )
212
+
1431
213
  """
1432
214
  matching_data.to_backend()
1433
- shape_diff = backend.subtract(
1434
- matching_data._output_target_shape, matching_data._output_template_shape
1435
- )
1436
- shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
1437
- if backend.sum(shape_diff < 0) and not pad_fourier:
1438
- warnings.warn(
1439
- "Target is larger than template and Fourier padding is turned off."
1440
- " This can lead to shifted results. You can swap template and target, or"
1441
- " zero-pad the target."
1442
- )
1443
215
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1444
216
  pad_fourier=pad_fourier
1445
217
  )
1446
218
 
1447
- callback_class_args["fourier_shift"] = fourier_shift
1448
- rfftn, irfftn = backend.build_fft(
219
+ rfftn, irfftn = be.build_fft(
1449
220
  fast_shape=fast_shape,
1450
221
  fast_ft_shape=fast_ft_shape,
1451
- real_dtype=matching_data._default_dtype,
1452
- complex_dtype=matching_data._complex_dtype,
1453
- fftargs=fftargs,
222
+ real_dtype=be._float_dtype,
223
+ complex_dtype=be._complex_dtype,
1454
224
  )
1455
-
1456
- # template_filter = _setup_template_filter(
1457
- # matching_data=matching_data,
1458
- # rfftn=rfftn,
1459
- # irfftn=irfftn,
1460
- # fast_shape=fast_shape,
1461
- # fast_ft_shape=fast_ft_shape,
1462
- # )
225
+ template_filter = _setup_template_filter_apply_target_filter(
226
+ matching_data=matching_data,
227
+ rfftn=rfftn,
228
+ irfftn=irfftn,
229
+ fast_shape=fast_shape,
230
+ fast_ft_shape=fast_ft_shape,
231
+ pad_template_filter=pad_template_filter,
232
+ )
233
+ template_filter = be.astype(be.to_backend_array(template_filter), be._float_dtype)
1463
234
 
1464
235
  setup = matching_setup(
1465
236
  rfftn=rfftn,
1466
237
  irfftn=irfftn,
1467
238
  template=matching_data.template,
239
+ template_filter=template_filter,
1468
240
  template_mask=matching_data.template_mask,
1469
241
  target=matching_data.target,
1470
242
  target_mask=matching_data.target_mask,
1471
243
  fast_shape=fast_shape,
1472
244
  fast_ft_shape=fast_ft_shape,
1473
- real_dtype=matching_data._default_dtype,
1474
- complex_dtype=matching_data._complex_dtype,
1475
- callback_class=callback_class,
1476
- callback_class_args=callback_class_args,
1477
- **kwargs,
245
+ shared_memory_handler=shared_memory_handler,
1478
246
  )
1479
247
  rfftn, irfftn = None, None
1480
-
1481
-
1482
- template_filter, preprocessor = None, Preprocessor()
1483
- for method, parameters in matching_data.template_filter.items():
1484
- parameters["shape"] = fast_shape
1485
- parameters["omit_negative_frequencies"] = True
1486
- out = preprocessor.apply_method(method=method, parameters=parameters)
1487
- if template_filter is None:
1488
- template_filter = out
1489
- np.multiply(template_filter, out, out=template_filter)
1490
-
1491
- if template_filter is None:
1492
- template_filter = backend.full(
1493
- shape=(1,), fill_value=1, dtype=backend._default_dtype
1494
- )
1495
-
1496
- template_filter = backend.to_backend_array(template_filter)
1497
- template_filter = backend.astype(template_filter, backend._default_dtype)
1498
- template_filter_buffer = backend.arr_to_sharedarr(
1499
- arr=template_filter,
1500
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
1501
- )
1502
- setup["template_filter"] = (
1503
- template_filter_buffer,
1504
- template_filter.shape,
1505
- template_filter.dtype,
1506
- )
1507
-
1508
- callback_class_args["translation_offset"] = backend.astype(
1509
- matching_data._translation_offset, int
1510
- )
1511
- callback_class_args["thread_safe"] = n_jobs > 1
1512
- callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
1513
-
1514
- n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
1515
- callback_class = setup.pop("callback_class", callback_class)
1516
- callback_class_args = setup.pop("callback_class_args", callback_class_args)
1517
- callback_classes = [callback_class for _ in range(n_callback_classes)]
1518
-
1519
- convolution_mode = "same"
1520
- if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1521
- convolution_mode = "valid"
1522
-
1523
-
1524
- callback_class_args["fourier_shift"] = fourier_shift
1525
- callback_class_args["convolution_mode"] = convolution_mode
1526
- callback_class_args["targetshape"] = setup["targetshape"]
1527
- callback_class_args["templateshape"] = setup["templateshape"]
1528
-
1529
- if callback_class == MaxScoreOverRotations:
1530
- callback_classes = [
1531
- class_name(
1532
- score_space_shape=fast_shape,
1533
- score_space_dtype=matching_data._default_dtype,
1534
- shared_memory_handler=kwargs.get("shared_memory_handler", None),
1535
- rotation_space_dtype=backend._default_dtype_int,
1536
- **callback_class_args,
1537
- )
1538
- for class_name in callback_classes
1539
- ]
1540
-
1541
- matching_data._target, matching_data._template = None, None
1542
- matching_data._target_mask, matching_data._template_mask = None, None
1543
-
1544
- setup["fftargs"] = fftargs.copy()
1545
- convolution_mode = "same"
1546
- if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1547
- convolution_mode = "valid"
1548
- setup["convolution_mode"] = convolution_mode
1549
248
  setup["interpolation_order"] = interpolation_order
1550
- rotation_list = matching_data._split_rotations_on_jobs(n_jobs)
1551
-
1552
- backend.free_cache()
1553
-
1554
- def _run_scoring(backend_name, backend_args, rotations, **kwargs):
1555
- from tme.backends import backend
249
+ setup["template_filter"] = be.to_sharedarr(template_filter, shared_memory_handler)
250
+
251
+ offset = be.to_backend_array(matching_data._translation_offset)
252
+ convmode = "valid" if getattr(matching_data, "_is_padded", False) else "same"
253
+ default_callback_args = {
254
+ "offset": be.astype(offset, be._int_dtype),
255
+ "thread_safe": n_jobs > 1,
256
+ "fourier_shift": fourier_shift,
257
+ "convolution_mode": convmode,
258
+ "targetshape": matching_data.target.shape,
259
+ "templateshape": matching_data.template.shape,
260
+ "fast_shape": fast_shape,
261
+ "indices": getattr(matching_data, "indices", None),
262
+ "shared_memory_handler": shared_memory_handler,
263
+ "only_unique_rotations": True,
264
+ }
265
+ default_callback_args.update(callback_class_args)
1556
266
 
1557
- backend.change_backend(backend_name, **backend_args)
1558
- return matching_score(rotations=rotations, **kwargs)
267
+ matching_data._free_data()
268
+ be.free_cache()
1559
269
 
270
+ # For highly parallel jobs, blocking in certain analyzers becomes a bottleneck
271
+ if getattr(callback_class, "shared", True):
272
+ jobs_per_callback_class = 1
273
+ n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
274
+ callback_classes = [
275
+ callback_class(
276
+ shape=fast_shape,
277
+ **default_callback_args,
278
+ )
279
+ if callback_class is not None
280
+ else None
281
+ for _ in range(n_callback_classes)
282
+ ]
1560
283
  callbacks = Parallel(n_jobs=n_jobs)(
1561
- delayed(_run_scoring)(
1562
- backend_name=backend._backend_name,
1563
- backend_args=backend._backend_args,
284
+ delayed(_wrap_backend(matching_score))(
285
+ backend_name=be._backend_name,
286
+ backend_args=be._backend_args,
1564
287
  rotations=rotation,
1565
- callback_class=callback_classes[index % n_callback_classes],
1566
- callback_class_args=callback_class_args,
288
+ callback=callback_classes[index % n_callback_classes],
1567
289
  **setup,
1568
290
  )
1569
- for index, rotation in enumerate(rotation_list)
291
+ for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
1570
292
  )
1571
293
 
1572
- callbacks = callbacks[0:n_callback_classes]
1573
294
  callbacks = [
1574
- tuple(callback._postprocess(
1575
- fourier_shift = fourier_shift,
1576
- convolution_mode = convolution_mode,
1577
- targetshape = setup["targetshape"],
1578
- templateshape = setup["templateshape"],
1579
- shared_memory_handler=kwargs.get("shared_memory_handler", None)
1580
- )) if hasattr(callback, "_postprocess") else tuple(callback)
1581
- for callback in callbacks if callback is not None
295
+ tuple(callback._postprocess(**default_callback_args))
296
+ for callback in callbacks
297
+ if callback is not None
1582
298
  ]
1583
- backend.free_cache()
299
+ be.free_cache()
1584
300
 
1585
- merged_callback = None
1586
301
  if callback_class is not None:
1587
- merged_callback = callback_class.merge(
1588
- callbacks,
1589
- **callback_class_args,
1590
- score_indices=matching_data.indices,
1591
- inner_merge=True,
1592
- )
1593
-
1594
- return merged_callback
302
+ return callback_class.merge(callbacks, **default_callback_args)
303
+ return None
1595
304
 
1596
305
 
1597
306
  def scan_subsets(
@@ -1605,19 +314,21 @@ def scan_subsets(
1605
314
  template_splits: Dict = {},
1606
315
  pad_target_edges: bool = False,
1607
316
  pad_fourier: bool = True,
317
+ pad_template_filter: bool = True,
1608
318
  interpolation_order: int = 3,
1609
319
  jobs_per_callback_class: int = 8,
1610
- **kwargs,
1611
- ) -> Tuple:
320
+ backend_name: str = None,
321
+ backend_args: Dict = {},
322
+ ) -> Optional[Tuple]:
1612
323
  """
1613
- Wrapper around :py:meth:`scan` that supports template matching on splits
1614
- of template and target.
324
+ Wrapper around :py:meth:`scan` that supports matching on splits
325
+ of ``matching_data``.
1615
326
 
1616
327
  Parameters
1617
328
  ----------
1618
- matching_data : MatchingData
1619
- Template matching data.
1620
- matching_func : type
329
+ matching_data : :py:class:`tme.matching_data.MatchingData`
330
+ MatchingData instance containing relevant data.
331
+ matching_setup : type
1621
332
  Function pointer to setup function.
1622
333
  matching_score : type
1623
334
  Function pointer to scoring function.
@@ -1626,88 +337,138 @@ def scan_subsets(
1626
337
  callback_class_args : dict, optional
1627
338
  Arguments passed to the callback_class. Default is an empty dictionary.
1628
339
  job_schedule : tuple of int, optional
1629
- Schedule of jobs. Default is (1, 1).
340
+ Job scheduling scheme, default is (1, 1). First value corresponds
341
+ to the number of splits that are processed in parallel, the second
342
+ to the number of angles evaluated in parallel on each split.
1630
343
  target_splits : dict, optional
1631
- Splits for target. Default is an empty dictionary, i.e. no splits
344
+ Splits for target. Default is an empty dictionary, i.e. no splits.
345
+ See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1632
346
  template_splits : dict, optional
1633
347
  Splits for template. Default is an empty dictionary, i.e. no splits.
348
+ See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1634
349
  pad_target_edges : bool, optional
1635
350
  Whether to pad the target boundaries by half the template shape
1636
351
  along each axis.
1637
352
  pad_fourier: bool, optional
1638
353
  Whether to pad target and template to the full convolution shape.
354
+ pad_template_filter: bool, optional
355
+ Whether to pad potential template filters to the full convolution shape.
1639
356
  interpolation_order : int, optional
1640
357
  Order of spline interpolation for rotations.
1641
358
  jobs_per_callback_class : int, optional
1642
359
  How many jobs should be processed by a single callback_class instance,
1643
360
  if ones is provided.
1644
- **kwargs : various
1645
- Additional arguments.
1646
-
1647
- Notes
1648
- -----
1649
- Objects in matching_data might be destroyed during computation.
1650
361
 
1651
362
  Returns
1652
363
  -------
1653
- Tuple
364
+ Optional[Tuple]
1654
365
  The merged results from callback_class if provided otherwise None.
366
+
367
+ Examples
368
+ --------
369
+ All data relevant to template matching will be contained in ``matching_data``, which
370
+ is a :py:class:`tme.matching_data.MatchingData` instance and can be created like so
371
+
372
+ >>> import numpy as np
373
+ >>> from tme.matching_data import MatchingData
374
+ >>> from tme.matching_utils import get_rotation_matrices
375
+ >>> target = np.random.rand(50,40,60)
376
+ >>> template = target[15:25, 10:20, 30:40]
377
+ >>> matching_data = MatchingData(target, template)
378
+ >>> matching_data.rotations = get_rotation_matrices(
379
+ >>> angular_sampling = 60, dim = target.ndim
380
+ >>> )
381
+
382
+ The template matching procedure is determined by ``matching_setup`` and
383
+ ``matching_score``, which are unique to each score. In the following,
384
+ we will be using the `FLCSphericalMask` score, which is composed of
385
+ :py:meth:`flcSphericalMask_setup` and :py:meth:`corr_scoring`
386
+
387
+ >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
388
+ >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
389
+ >>> matching_setup, matching_score = funcs
390
+
391
+ Computed scores are flexibly analyzed by being passed through an analyzer. In the
392
+ following, we will use :py:class:`tme.analyzer.MaxScoreOverRotations` to
393
+ aggregate sores over rotations
394
+
395
+ >>> from tme.analyzer import MaxScoreOverRotations
396
+ >>> callback_class = MaxScoreOverRotations
397
+ >>> callback_class_args = {"score_threshold" : 0}
398
+
399
+ In case the entire template matching problem does not fit into memory, we can
400
+ determine the splitting procedure. In this case, we halv the first axis of the target
401
+ once. Splitting and ``job_schedule`` is typically computed using
402
+ :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
403
+
404
+ >>> target_splits = {0 : 1}
405
+
406
+ Finally, we can perform template matching. Note that the data
407
+ contained in ``matching_data`` will be destroyed when running the following
408
+
409
+ >>> from tme.matching_exhaustive import scan_subsets
410
+ >>> results = scan_subsets(
411
+ >>> matching_data = matching_data,
412
+ >>> matching_score = matching_score,
413
+ >>> matching_setup = matching_setup,
414
+ >>> callback_class = callback_class,
415
+ >>> callback_class_args = callback_class_args,
416
+ >>> target_splits = target_splits,
417
+ >>> )
418
+
419
+ The ``results`` tuple contains the output of the chosen analyzer.
420
+
421
+ See Also
422
+ --------
423
+ :py:meth:`tme.matching_utils.compute_parallelization_schedule`
1655
424
  """
1656
- target_splits = split_numpy_array_slices(
1657
- matching_data._target.shape, splits=target_splits
1658
- )
1659
- template_splits = split_numpy_array_slices(
1660
- matching_data._template.shape, splits=template_splits
1661
- )
1662
- target_pad = matching_data.target_padding(pad_target=pad_target_edges)
425
+ template_splits = split_shape(matching_data._template.shape, splits=template_splits)
426
+ target_splits = split_shape(matching_data._target.shape, splits=target_splits)
427
+ if (len(target_splits) > 1) and not pad_target_edges:
428
+ warnings.warn(
429
+ "Target splitting without padding target edges leads to unreliable "
430
+ "similarity estimates around the split border."
431
+ )
432
+ splits = tuple(product(target_splits, template_splits))
1663
433
 
1664
434
  outer_jobs, inner_jobs = job_schedule
1665
- results = Parallel(n_jobs=outer_jobs)(
1666
- delayed(_run_inner)(
1667
- backend_name=backend._backend_name,
1668
- backend_args=backend._backend_args,
1669
- matching_data=matching_data.subset_by_slice(
1670
- target_slice=target_split,
1671
- target_pad=target_pad,
1672
- template_slice=template_split,
1673
- ),
1674
- matching_score=matching_score,
1675
- matching_setup=matching_setup,
1676
- n_jobs=inner_jobs,
435
+ target_pad = matching_data.target_padding(pad_target=pad_target_edges)
436
+ if hasattr(be, "scan"):
437
+ corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
438
+ results = be.scan(
439
+ matching_data=matching_data,
440
+ splits=splits,
441
+ n_jobs=outer_jobs,
442
+ rotate_mask=matching_score != corr_scoring,
1677
443
  callback_class=callback_class,
1678
- callback_class_args=callback_class_args,
1679
- interpolation_order=interpolation_order,
1680
- pad_fourier=pad_fourier,
1681
- gpu_index=index % outer_jobs,
1682
- **kwargs,
1683
444
  )
1684
- for index, (target_split, template_split) in enumerate(
1685
- product(target_splits, template_splits)
445
+ else:
446
+ results = Parallel(n_jobs=outer_jobs)(
447
+ delayed(_wrap_backend(scan))(
448
+ backend_name=be._backend_name,
449
+ backend_args=be._backend_args,
450
+ matching_data=matching_data.subset_by_slice(
451
+ target_slice=target_split,
452
+ target_pad=target_pad,
453
+ template_slice=template_split,
454
+ ),
455
+ matching_score=matching_score,
456
+ matching_setup=matching_setup,
457
+ n_jobs=inner_jobs,
458
+ callback_class=callback_class,
459
+ callback_class_args=callback_class_args,
460
+ interpolation_order=interpolation_order,
461
+ pad_fourier=pad_fourier,
462
+ gpu_index=index % outer_jobs,
463
+ pad_template_filter=pad_template_filter,
464
+ )
465
+ for index, (target_split, template_split) in enumerate(splits)
1686
466
  )
1687
- )
1688
-
1689
- matching_data._target, matching_data._template = None, None
1690
- matching_data._target_mask, matching_data._template_mask = None, None
1691
467
 
1692
- candidates = None
468
+ matching_data._free_data()
1693
469
  if callback_class is not None:
1694
- candidates = callback_class.merge(
1695
- results, **callback_class_args, inner_merge=False
1696
- )
1697
-
1698
- return candidates
1699
-
1700
-
1701
- MATCHING_EXHAUSTIVE_REGISTER = {
1702
- "CC": (cc_setup, corr_scoring),
1703
- "LCC": (lcc_setup, corr_scoring),
1704
- "CORR": (corr_setup, corr_scoring),
1705
- "CAM": (cam_setup, corr_scoring),
1706
- "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1707
- "FLC": (flc_setup, flc_scoring),
1708
- "FLC2": (flc_setup, flc_scoring2),
1709
- "MCC": (mcc_setup, mcc_scoring),
1710
- }
470
+ return callback_class.merge(results, **callback_class_args)
471
+ return None
1711
472
 
1712
473
 
1713
474
  def register_matching_exhaustive(
@@ -1724,20 +485,17 @@ def register_matching_exhaustive(
1724
485
  matching : str
1725
486
  Name of the matching method.
1726
487
  matching_setup : Callable
1727
- The setup function associated with the name.
488
+ Corresponding setup function.
1728
489
  matching_scoring : Callable
1729
- The scoring function associated with the name.
490
+ Corresponing scoring function.
1730
491
  memory_class : MatchingMemoryUsage
1731
- The custom memory estimation class extending
1732
- :py:class:`tme.matching_memory.MatchingMemoryUsage`.
492
+ Child of :py:class:`tme.memory.MatchingMemoryUsage`.
1733
493
 
1734
494
  Raises
1735
495
  ------
1736
496
  ValueError
1737
- If a function with the name ``matching`` already exists in the registry.
1738
- ValueError
1739
- If ``memory_class`` is not a subclass of
1740
- :py:class:`tme.matching_memory.MatchingMemoryUsage`.
497
+ If a function with the name ``matching`` already exists in the registry, or
498
+ if ``memory_class`` is no child of :py:class:`tme.memory.MatchingMemoryUsage`.
1741
499
  """
1742
500
 
1743
501
  if matching in MATCHING_EXHAUSTIVE_REGISTER: