pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
  2. pytme-0.1.5.data/scripts/match_template.py +744 -0
  3. pytme-0.1.5.data/scripts/postprocess.py +279 -0
  4. pytme-0.1.5.data/scripts/preprocess.py +93 -0
  5. pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
  6. pytme-0.1.5.dist-info/LICENSE +153 -0
  7. pytme-0.1.5.dist-info/METADATA +69 -0
  8. pytme-0.1.5.dist-info/RECORD +63 -0
  9. pytme-0.1.5.dist-info/WHEEL +5 -0
  10. pytme-0.1.5.dist-info/entry_points.txt +6 -0
  11. pytme-0.1.5.dist-info/top_level.txt +2 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +81 -0
  14. scripts/match_template.py +744 -0
  15. scripts/match_template_devel.py +788 -0
  16. scripts/postprocess.py +279 -0
  17. scripts/preprocess.py +93 -0
  18. scripts/preprocessor_gui.py +729 -0
  19. tme/__init__.py +6 -0
  20. tme/__version__.py +1 -0
  21. tme/analyzer.py +1144 -0
  22. tme/backends/__init__.py +134 -0
  23. tme/backends/cupy_backend.py +309 -0
  24. tme/backends/matching_backend.py +1154 -0
  25. tme/backends/npfftw_backend.py +763 -0
  26. tme/backends/pytorch_backend.py +526 -0
  27. tme/data/__init__.py +0 -0
  28. tme/data/c48n309.npy +0 -0
  29. tme/data/c48n527.npy +0 -0
  30. tme/data/c48n9.npy +0 -0
  31. tme/data/c48u1.npy +0 -0
  32. tme/data/c48u1153.npy +0 -0
  33. tme/data/c48u1201.npy +0 -0
  34. tme/data/c48u1641.npy +0 -0
  35. tme/data/c48u181.npy +0 -0
  36. tme/data/c48u2219.npy +0 -0
  37. tme/data/c48u27.npy +0 -0
  38. tme/data/c48u2947.npy +0 -0
  39. tme/data/c48u3733.npy +0 -0
  40. tme/data/c48u4749.npy +0 -0
  41. tme/data/c48u5879.npy +0 -0
  42. tme/data/c48u7111.npy +0 -0
  43. tme/data/c48u815.npy +0 -0
  44. tme/data/c48u83.npy +0 -0
  45. tme/data/c48u8649.npy +0 -0
  46. tme/data/c600v.npy +0 -0
  47. tme/data/c600vc.npy +0 -0
  48. tme/data/metadata.yaml +80 -0
  49. tme/data/quat_to_numpy.py +42 -0
  50. tme/data/scattering_factors.pickle +0 -0
  51. tme/density.py +2314 -0
  52. tme/extensions.cpython-311-darwin.so +0 -0
  53. tme/helpers.py +881 -0
  54. tme/matching_data.py +377 -0
  55. tme/matching_exhaustive.py +1553 -0
  56. tme/matching_memory.py +382 -0
  57. tme/matching_optimization.py +1123 -0
  58. tme/matching_utils.py +1180 -0
  59. tme/parser.py +429 -0
  60. tme/preprocessor.py +1291 -0
  61. tme/scoring.py +866 -0
  62. tme/structure.py +1428 -0
  63. tme/types.py +10 -0
@@ -0,0 +1,1553 @@
1
+ """ Implements template matching based on different metrics. The functions
2
+ implemented here are equivalent to what is implemented in `Fitter`,
3
+ but removing them out of the Fitter classes reduces the time it takes
4
+ when initializing the multiprocessing framework. This is useful in cases,
5
+ where the number of sampled rotations is low, but the number of input
6
+ arrays is high, e.g. split fitting and evaluating different templates.
7
+
8
+ Copyright (c) 2023 European Molecular Biology Laboratory
9
+
10
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
11
+ """
12
+ import os
13
+ import sys
14
+ from copy import deepcopy
15
+ from itertools import product
16
+ from typing import Callable, Tuple, Dict
17
+ from functools import wraps
18
+ from joblib import Parallel, delayed
19
+ from multiprocessing.managers import SharedMemoryManager
20
+
21
+ import numpy as np
22
+ from scipy.ndimage import laplace
23
+
24
+ from .analyzer import MaxScoreOverRotations
25
+ from .matching_utils import (
26
+ apply_convolution_mode,
27
+ handle_traceback,
28
+ split_numpy_array_slices,
29
+ conditional_execute,
30
+ )
31
+ from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
32
+ from .preprocessor import Preprocessor
33
+ from .matching_data import MatchingData
34
+ from .backends import backend
35
+ from .types import NDArray, CallbackClass
36
+
37
# Restrict the thread pools of the numerical libraries (MKL, OpenMP, pyFFTW,
# OpenBLAS) to a single thread each.
# NOTE(review): numpy is imported before this point; libraries that read these
# variables only when they are first loaded may not pick up the change in this
# process — confirm whether the intent is mainly to affect child processes.
os.environ.update(
    dict.fromkeys(
        (
            "MKL_NUM_THREADS",
            "OMP_NUM_THREADS",
            "PYFFTW_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
        ),
        "1",
    )
)
41
+
42
+
43
def _run_inner(backend_name, backend_args, **kwargs):
    """
    Switch to the requested computation backend and run :py:meth:`scan`.

    Presumably used as an entry point for worker processes, which have to
    select their backend themselves before scanning — confirm against callers.

    Parameters
    ----------
    backend_name : str
        Name of the backend, forwarded to ``backend.change_backend``.
    backend_args : dict
        Keyword arguments forwarded to ``backend.change_backend``.
    **kwargs
        Keyword arguments forwarded to :py:meth:`scan`.

    Returns
    -------
    Whatever :py:meth:`scan` returns.
    """
    # Resolve the backend object from the package inside this call rather than
    # relying on the module-level binding.
    from tme.backends import backend

    backend.change_backend(backend_name, **backend_args)
    return scan(**kwargs)
50
+
51
+
52
def cc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Prepare the shared inputs for a plain cross-correlation between a target f
    and a template g:

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    The target is padded to ``fast_shape`` and Fourier transformed once. The
    template is shared as is. ``inv_denominator`` and ``numerator2`` are shared
    as size-one placeholders, for which :py:meth:`corr_scoring` skips the
    normalization step.

    Returns
    -------
    Dict
        Keyword arguments for the scoring function. Array entries are
        ``(shared_buffer, shape, dtype)`` tuples.

    See Also
    --------
    :py:meth:`corr_scoring`
    :py:class:`tme.matching_optimization.CrossCorrelation`
    """
    original_target_shape = target.shape

    padded_target = backend.topleft_pad(target, fast_shape)
    target_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(padded_target, target_spectrum)

    def to_shared(array):
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    shared_target_ft = to_shared(target_spectrum)
    shared_template = to_shared(template)
    # Size-one placeholders: an inverse denominator of preallocate + 1 and a
    # freshly preallocated numerator2.
    shared_inv_denominator = to_shared(backend.preallocate_array(1, real_dtype) + 1)
    shared_numerator2 = to_shared(backend.preallocate_array(1, real_dtype))

    return dict(
        template=(shared_template, template.shape, real_dtype),
        ft_target=(shared_target_ft, fast_ft_shape, complex_dtype),
        inv_denominator=(shared_inv_denominator, (1,), real_dtype),
        numerator2=(shared_numerator2, (1,), real_dtype),
        targetshape=original_target_shape,
        templateshape=template.shape,
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        callback_class=callback_class,
        callback_class_args=callback_class_args,
        use_memmap=kwargs.get("use_memmap", False),
        temp_dir=kwargs.get("temp_dir", None),
    )
126
+
127
+
128
def lcc_setup(**kwargs) -> Dict:
    """
    Setup to compute the cross-correlation between a laplace transformed target f
    and laplace transformed template g defined as:

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)

    The Laplacian (with ``mode="wrap"``) replaces ``kwargs["target"]`` and
    ``kwargs["template"]``. All keyword arguments are then forwarded to
    :py:meth:`flcSphericalMask_setup`, so the normalization performed there
    also applies to this score.

    CuPy arrays are filtered with ``cupyx.scipy.ndimage.laplace``. All other
    arrays are filtered with ``scipy.ndimage.laplace``, so CuPy is only needed
    when CuPy arrays are passed.

    See Also
    --------
    :py:meth:`corr_scoring`
    :py:class:`tme.matching_optimization.LaplaceCrossCorrelation`
    """

    def _wrap_laplace(arr):
        # Import cupyx lazily and only for CuPy arrays, so CPU-only
        # installations keep working.
        if type(arr).__module__.partition(".")[0] == "cupy":
            from cupyx.scipy.ndimage import laplace as cupy_laplace

            return cupy_laplace(arr, mode="wrap")
        return laplace(arr, mode="wrap")

    kwargs["target"] = _wrap_laplace(kwargs["target"])
    kwargs["template"] = _wrap_laplace(kwargs["template"])
    return flcSphericalMask_setup(**kwargs)
147
+
148
+
149
def corr_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Prepare the shared inputs for a normalized cross-correlation between a
    target f and a template g. The score assembled during scoring is

    .. math::

        \\frac{CC(f,g) - \\overline{g} \\cdot S_1}
        {\\sqrt{(S_2 - \\frac{S_1^2}{N_g}) \\cdot \\sum (g - \\overline{g})^2}}

    with

    .. math::

        S_1 = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(m)), \\quad
        S_2 = \\mathcal{F}^{-1}(\\mathcal{F}(f^2) \\cdot \\mathcal{F}(m))

    where m is ``template_mask`` and :math:`N_g` is the number of template voxels.
    The subtracted term is shared as ``numerator2`` and the reciprocal
    denominator as ``inv_denominator``. The reciprocal is only written where the
    denominator exceeds machine epsilon.

    References
    ----------
    .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
           and Magic.

    See Also
    --------
    :py:meth:`corr_scoring`
    :py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
    """
    padded_target = backend.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable: scikit-image's
    # match_template scales the running sum of the target with a factor derived
    # from the template, which is presumably appropriate when the template is
    # known to occur in the target.
    padded_mask = backend.topleft_pad(template_mask, fast_shape)
    mask_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(padded_mask, mask_spectrum)
    del padded_mask

    target_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    target_sq_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    denominator = backend.preallocate_array(fast_shape, real_dtype)
    window_sum = backend.preallocate_array(fast_shape, real_dtype)
    rfftn(padded_target, target_spectrum)

    # S_2: squared target combined with the mask window
    rfftn(backend.square(padded_target), target_sq_spectrum)
    backend.multiply(target_sq_spectrum, mask_spectrum, out=target_sq_spectrum)
    irfftn(target_sq_spectrum, denominator)

    # S_1: target combined with the mask window (mask spectrum buffer is reused)
    backend.multiply(target_spectrum, mask_spectrum, out=mask_spectrum)
    irfftn(mask_spectrum, window_sum)

    del padded_target, target_sq_spectrum, mask_spectrum

    template_mean = backend.mean(template)
    n_template = np.prod(template.shape)
    template_ssd = backend.sum(
        backend.square(backend.subtract(template, template_mean))
    )

    # Subtracted from the raw correlation during scoring
    numerator2 = backend.multiply(window_sum, template_mean)

    # denominator <- sqrt(max((S_2 - S_1**2 / N_g) * ssd, 0)), computed in place
    backend.multiply(window_sum, window_sum, out=window_sum)
    backend.divide(window_sum, n_template, out=window_sum)
    backend.subtract(denominator, window_sum, out=denominator)
    backend.multiply(denominator, template_ssd, out=denominator)
    backend.maximum(denominator, 0, out=denominator)
    backend.sqrt(denominator, out=denominator)
    del window_sum

    # Elsewhere the reciprocal keeps whatever preallocate_array initialised it
    # with (presumably zero).
    valid = denominator > backend.eps(denominator.dtype)
    inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
    inv_denominator[valid] = 1 / denominator[valid]

    def to_shared(array):
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    shared_template = to_shared(template)
    shared_target_ft = to_shared(target_spectrum)
    shared_inv_denominator = to_shared(inv_denominator)
    shared_numerator2 = to_shared(numerator2)

    return dict(
        template=(shared_template, deepcopy(template.shape), real_dtype),
        ft_target=(shared_target_ft, fast_ft_shape, complex_dtype),
        inv_denominator=(shared_inv_denominator, fast_shape, real_dtype),
        numerator2=(shared_numerator2, fast_shape, real_dtype),
        targetshape=deepcopy(target.shape),
        templateshape=deepcopy(template.shape),
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        callback_class=callback_class,
        callback_class_args=callback_class_args,
        template_mean=kwargs.get("template_mean", template_mean),
    )
281
+
282
+
283
def cam_setup(**kwargs):
    """
    Setup to compute a normalized cross-correlation between a target f and a
    template g over their means, expressed through the CORR score of
    :py:meth:`corr_scoring`:

    .. math::

        \\text{CORR}(f-\\overline{f}, g - \\overline{g})

    where :math:`\\overline{f}` and :math:`\\overline{g}` are the means of f and
    g. This function only centers ``kwargs["template"]`` on its mean. All
    keyword arguments, including the centered template, are then forwarded to
    :py:meth:`corr_setup`.

    References
    ----------
    .. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
           and Magic.

    See Also
    --------
    :py:meth:`corr_scoring`.
    :py:class:`tme.matching_optimization.NormalizedCrossCorrelationMean`.
    """
    original_template = kwargs["template"]
    centered_template = original_template - original_template.mean()
    return corr_setup(**{**kwargs, "template": centered_template})
316
+
317
+
318
def flc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Setup to compute a normalized cross-correlation score of a target f, a
    template g and a mask m:

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and :math:`N_m` is the number of voxels within the template mask m.

    The Fourier transforms of the padded target and of the squared padded
    target are shared. The template is normalized to zero mean and unit
    standard deviation, using statistics from voxels where ``template_mask > 0``,
    and then multiplied by the binarized mask. The normalized template is a new
    array, so the caller's ``template`` is left unchanged. The binarized mask is
    shared as ``real_dtype``.

    References
    ----------
    .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
           Microsc. Microanal. 26, 2516 (2020)
    .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
           and F. Förster, J. Struct. Biol. 178, 177 (2012).

    See Also
    --------
    :py:meth:`flc_scoring`
    """
    target_pad = backend.topleft_pad(target, fast_shape)

    # Fourier transforms of the target and the squared target
    ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(target_pad, ft_target)
    rfftn(backend.square(target_pad), ft_target2)

    # Convert arrays used in subsequent fitting to SharedMemory objects
    ft_target = backend.arr_to_sharedarr(
        arr=ft_target, shared_memory_handler=shared_memory_handler
    )
    ft_target2 = backend.arr_to_sharedarr(
        arr=ft_target2, shared_memory_handler=shared_memory_handler
    )

    template_mask = template_mask > 0
    template_mean = backend.mean(template[template_mask])
    template_std = backend.std(template[template_mask])
    template_mask = backend.astype(template_mask, real_dtype)

    # Normalize into a new array instead of writing back into the caller's
    # template (matches flcSphericalMask_setup).
    template = backend.divide(backend.subtract(template, template_mean), template_std)
    backend.multiply(template, template_mask, out=template)

    template_buffer = backend.arr_to_sharedarr(
        arr=template, shared_memory_handler=shared_memory_handler
    )
    template_mask_buffer = backend.arr_to_sharedarr(
        arr=template_mask, shared_memory_handler=shared_memory_handler
    )

    template_tuple = (template_buffer, template.shape, real_dtype)
    template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)

    target_ft_tuple = (ft_target, fast_ft_shape, complex_dtype)
    target_ft2_tuple = (ft_target2, fast_ft_shape, complex_dtype)

    ret = {
        "template": template_tuple,
        "template_mask": template_mask_tuple,
        "ft_target": target_ft_tuple,
        "ft_target2": target_ft2_tuple,
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }

    return ret
416
+
417
+
418
def flcSphericalMask_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Variant of :py:meth:`flc_setup` for rotation invariant masks. The local
    normalization is computed once from the unrotated mask, so the mask does
    not have to be Fourier transformed again for each rotation.

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and :math:`N_m` is the sum of the padded template mask m.

    The reciprocal of the denominator is shared as ``inv_denominator``. It is
    zero wherever the denominator does not exceed ``1e3 * eps * max|denominator|``.
    ``numerator2`` is a size-one placeholder.

    References
    ----------
    .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
           Microsc. Microanal. 26, 2516 (2020)
    .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
           and F. Förster, J. Struct. Biol. 178, 177 (2012).

    See Also
    --------
    :py:meth:`corr_scoring`
    """
    padded_target = backend.topleft_pad(target, fast_shape)
    padded_mask = backend.topleft_pad(template_mask, fast_shape)

    target_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    mask_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    scratch_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)

    first_moment = backend.preallocate_array(fast_shape, real_dtype)
    second_moment = backend.preallocate_array(fast_shape, real_dtype)
    numerator2 = backend.preallocate_array(1, real_dtype)

    machine_eps = backend.eps(real_dtype)
    n_observations = backend.sum(padded_mask)
    rfftn(padded_mask, mask_spectrum)

    def windowed_mean(out):
        # out <- F^-1(target_spectrum * mask_spectrum) / N_m
        backend.multiply(target_spectrum, mask_spectrum, out=scratch_spectrum)
        irfftn(scratch_spectrum, out)
        backend.divide(out, n_observations, out=out)

    # Second moment of the target under the mask window
    rfftn(backend.square(padded_target), target_spectrum)
    windowed_mean(second_moment)

    # Squared first moment; target_spectrum ends up holding F(target)
    rfftn(padded_target, target_spectrum)
    windowed_mean(first_moment)
    backend.square(first_moment, out=first_moment)

    # N_m * sqrt(max(E[f^2] - E[f]^2, 0)), stored in the first-moment buffer
    local_std = first_moment
    backend.subtract(second_moment, local_std, out=local_std)
    backend.maximum(local_std, 0.0, out=local_std)
    backend.sqrt(local_std, out=local_std)
    backend.multiply(local_std, n_observations, out=local_std)

    threshold = 1e3 * machine_eps * backend.max(backend.abs(local_std))
    above_threshold = local_std > threshold

    # Reuse the second-moment buffer for the reciprocal denominator
    inv_denominator = second_moment
    backend.fill(inv_denominator, 0)
    inv_denominator[above_threshold] = 1 / local_std[above_threshold]

    binary_mask = template_mask > 0
    template_mean = backend.mean(template[binary_mask])
    template_std = backend.std(template[binary_mask])

    normalized = backend.divide(backend.subtract(template, template_mean), template_std)
    backend.multiply(normalized, binary_mask, out=normalized)

    def to_shared(array):
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    shared_template = to_shared(normalized)
    shared_target_ft = to_shared(target_spectrum)
    shared_inv_denominator = to_shared(inv_denominator)
    shared_numerator2 = to_shared(numerator2)

    return dict(
        template=(shared_template, normalized.shape, real_dtype),
        ft_target=(shared_target_ft, fast_ft_shape, complex_dtype),
        inv_denominator=(shared_inv_denominator, fast_shape, real_dtype),
        numerator2=(shared_numerator2, (1,), real_dtype),
        targetshape=target.shape,
        templateshape=normalized.shape,
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        callback_class=callback_class,
        callback_class_args=callback_class_args,
    )
548
+
549
+
550
def mcc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    target_mask: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Setup to compute a normalized cross-correlation score with masks
    for both template and target.

    .. math::

        \\frac{
            CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
        }{
            \\sqrt{
                (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
                (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
            }
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    The Fourier transforms of the masked target, the squared masked target and
    the target mask are shared, each padded to ``fast_shape``. The template and
    template mask are shared unchanged.

    Notes
    -----
    The target is multiplied by ``target_mask > 0`` in place (``out=target``),
    so the caller's target array is modified.

    References
    ----------
    .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference

    See Also
    --------
    :py:meth:`mcc_scoring`
    :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
    """
    # In place on purpose: avoids allocating a second target-sized array.
    target = backend.multiply(target, target_mask > 0, out=target)

    target_pad = backend.topleft_pad(target, fast_shape)
    target_mask_pad = backend.topleft_pad(target_mask, fast_shape)

    target_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(target_pad, target_ft)
    target_ft_buffer = backend.arr_to_sharedarr(
        arr=target_ft, shared_memory_handler=shared_memory_handler
    )

    target_ft2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(backend.square(target_pad), target_ft2)
    target_ft2_buffer = backend.arr_to_sharedarr(
        arr=target_ft2, shared_memory_handler=shared_memory_handler
    )

    target_mask_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(target_mask_pad, target_mask_ft)
    target_mask_ft_buffer = backend.arr_to_sharedarr(
        arr=target_mask_ft, shared_memory_handler=shared_memory_handler
    )

    # NOTE(review): template and template_mask are shared without a dtype
    # conversion but described as real_dtype below — assumes callers already
    # pass real_dtype arrays; confirm.
    template_buffer = backend.arr_to_sharedarr(
        arr=template, shared_memory_handler=shared_memory_handler
    )
    template_mask_buffer = backend.arr_to_sharedarr(
        arr=template_mask, shared_memory_handler=shared_memory_handler
    )

    template_tuple = (template_buffer, template.shape, real_dtype)
    # Record the mask's own shape (consistent with flc_setup)
    template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)

    target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
    target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
    target_mask_ft_tuple = (target_mask_ft_buffer, fast_ft_shape, complex_dtype)

    ret = {
        "template": template_tuple,
        "template_mask": template_mask_tuple,
        "ft_target": target_ft_tuple,
        "ft_target2": target_ft2_tuple,
        "ft_target_mask": target_mask_ft_tuple,
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }

    return ret
651
+
652
+
653
def corr_scoring(
    template: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    inv_denominator: Tuple[type, Tuple[int], type],
    numerator2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    convolution_mode: str = "full",
    **kwargs,
) -> CallbackClass:
    """
    Score every rotation of the template against the target as

    .. math::

        (CC(f,g) - numerator2) \\cdot inv\\_denominator

    The subtraction is skipped when ``numerator2`` has size one or sums to zero.
    The multiplication is skipped when ``inv_denominator`` has size one or sums
    to one. The template spectrum is multiplied by ``template_filter`` only when
    the filter is non-empty. Scores are rolled by
    ``callback_class_args["fourier_shift"]`` when it is non-zero. They are then
    cropped according to ``convolution_mode``.

    Parameters
    ----------
    template : Tuple[type, Tuple[int], type]
        Shared buffer, shape and dtype of the template.
    ft_target : Tuple[type, Tuple[int], type]
        Shared buffer, shape and dtype of the Fourier transformed target.
    inv_denominator : Tuple[type, Tuple[int], type]
        Shared buffer, shape and dtype of the inverse denominator.
    numerator2 : Tuple[type, Tuple[int], type]
        Shared buffer, shape and dtype of the term subtracted from CC(f, g).
    template_filter : Tuple[type, Tuple[int], type]
        Shared buffer, shape and dtype of a filter applied to the template
        spectrum. An empty filter disables filtering.
    targetshape : Tuple[int]
        Shape of the target.
    templateshape : Tuple[int]
        Shape of the template.
    fast_shape : Tuple[int]
        Real-space shape used for the Fourier transforms.
    fast_ft_shape : Tuple[int]
        Shape of the real-to-complex Fourier transform of ``fast_shape``.
    rotations : NDArray
        Rotation matrices applied to the template, stacked along axis 0.
    real_dtype : type
        Floating point dtype of real-valued arrays.
    complex_dtype : type
        Complex dtype of Fourier-space arrays.
    callback_class : CallbackClass
        A class, which is instantiated with ``callback_class_args``, or an
        already-constructed callable, which is used as is. It receives each
        score.
    callback_class_args : Dict
        Constructor arguments for ``callback_class``. They are also passed to
        every callback invocation.
    interpolation_order : int
        Spline order used when rotating the template.
    convolution_mode : str, optional
        Output cropping mode passed to ``apply_convolution_mode``. The default
        is "full".
    **kwargs
        ``fftargs`` is forwarded to ``backend.build_fft``. Other keys are ignored.

    Returns
    -------
    CallbackClass
        The callback instance, or None if ``callback_class`` was None.

    See Also
    --------
    :py:meth:`cc_setup`
    :py:meth:`corr_setup`
    :py:meth:`cam_setup`
    :py:meth:`flcSphericalMask_setup`
    """
    template_buf, template_shape, template_dtype = template
    target_ft_buf, target_ft_shape, target_ft_dtype = ft_target
    inv_denom_buf, inv_denom_shape, _ = inv_denominator
    num2_buf, num2_shape, _ = numerator2
    filter_buf, filter_shape, filter_dtype = template_filter

    # Classes get instantiated; instances (and None) are used directly.
    if isinstance(callback_class, type):
        callback = callback_class(**callback_class_args)
    else:
        callback = callback_class

    template_arr = backend.sharedarr_to_arr(template_shape, template_dtype, template_buf)
    target_spectrum = backend.sharedarr_to_arr(
        target_ft_shape, target_ft_dtype, target_ft_buf
    )
    inv_denom_arr = backend.sharedarr_to_arr(
        inv_denom_shape, template_dtype, inv_denom_buf
    )
    num2_arr = backend.sharedarr_to_arr(num2_shape, template_dtype, num2_buf)
    filter_arr = backend.sharedarr_to_arr(filter_shape, filter_dtype, filter_buf)

    work = backend.preallocate_array(fast_shape, real_dtype)
    work_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=work,
        temp_fft=work_spectrum,
    )

    # Size-one normalization arrays are placeholders and disable their step.
    apply_numerator = (backend.sum(num2_arr) != 0) & (backend.size(num2_arr) != 1)
    apply_denominator = (backend.sum(inv_denom_arr) != 1) & (
        backend.size(inv_denom_arr) != 1
    )
    apply_filter = backend.size(filter_arr) != 0

    subtract_numerator = conditional_execute(backend.subtract, apply_numerator)
    scale_by_denominator = conditional_execute(backend.multiply, apply_denominator)
    filter_spectrum = conditional_execute(backend.multiply, apply_filter)

    roll_axes = tuple(range(work.ndim))
    fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(work.ndim))
    shift_scores = backend.sum(fourier_shift != 0) != 0

    for index, rotation in enumerate(rotations):
        backend.fill(work, 0)
        backend.rotate_array(
            arr=template_arr,
            rotation_matrix=rotation,
            out=work,
            use_geometric_center=False,
            order=interpolation_order,
        )
        rfftn(work, work_spectrum)
        filter_spectrum(work_spectrum, filter_arr, out=work_spectrum)

        backend.multiply(target_spectrum, work_spectrum, out=work_spectrum)
        irfftn(work_spectrum, work)

        subtract_numerator(work, num2_arr, out=work)
        scale_by_denominator(work, inv_denom_arr, out=work)

        if shift_scores:
            # roll returns a new array; later iterations keep working on it
            work = backend.roll(work, shift=fourier_shift, axis=roll_axes)

        score = apply_convolution_mode(
            work, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )

        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback
818
+
819
+
820
+ def flc_scoring(
821
+ template: Tuple[type, Tuple[int], type],
822
+ template_mask: Tuple[type, Tuple[int], type],
823
+ ft_target: Tuple[type, Tuple[int], type],
824
+ ft_target2: Tuple[type, Tuple[int], type],
825
+ template_filter: Tuple[type, Tuple[int], type],
826
+ targetshape: Tuple[int],
827
+ templateshape: Tuple[int],
828
+ fast_shape: Tuple[int],
829
+ fast_ft_shape: Tuple[int],
830
+ rotations: NDArray,
831
+ real_dtype: type,
832
+ complex_dtype: type,
833
+ callback_class: CallbackClass,
834
+ callback_class_args: Dict,
835
+ interpolation_order: int,
836
+ **kwargs,
837
+ ) -> CallbackClass:
838
+ """
839
+ Computes a normalized cross-correlation score of a target f a template g
840
+ and a mask m:
841
+
842
+ .. math::
843
+
844
+ \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
845
+ {N_m * \\sqrt{
846
+ \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
847
+ }
848
+
849
+ Where:
850
+
851
+ .. math::
852
+
853
+ CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
854
+
855
+ and Nm is the number of voxels within the template mask m.
856
+
857
+ References
858
+ ----------
859
+ .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
860
+ Microsc. Microanal. 26, 2516 (2020)
861
+ .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
862
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
863
+ """
864
+ template_buffer, template_shape, template_dtype = template
865
+ template_mask_buffer, *_ = template_mask
866
+ filter_buffer, filter_shape, filter_dtype = template_filter
867
+
868
+ ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
869
+ ft_target2_buffer, *_ = ft_target2
870
+
871
+ if callback_class is not None and isinstance(callback_class, type):
872
+ callback = callback_class(**callback_class_args)
873
+ elif not isinstance(callback_class, type):
874
+ callback = callback_class
875
+
876
+ # Retrieve objects from shared memory
877
+ template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
878
+ template_mask = backend.sharedarr_to_arr(
879
+ template_shape, template_dtype, template_mask_buffer
880
+ )
881
+ ft_target = backend.sharedarr_to_arr(
882
+ ft_target_shape, ft_target_dtype, ft_target_buffer
883
+ )
884
+ ft_target2 = backend.sharedarr_to_arr(
885
+ ft_target_shape, ft_target_dtype, ft_target2_buffer
886
+ )
887
+ template_filter = backend.sharedarr_to_arr(
888
+ filter_shape, filter_dtype, filter_buffer
889
+ )
890
+
891
+ arr = backend.preallocate_array(fast_shape, real_dtype)
892
+ temp = backend.preallocate_array(fast_shape, real_dtype)
893
+ temp2 = backend.preallocate_array(fast_shape, real_dtype)
894
+
895
+ ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
896
+ ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
897
+
898
+ rfftn, irfftn = backend.build_fft(
899
+ fast_shape=fast_shape,
900
+ fast_ft_shape=fast_ft_shape,
901
+ real_dtype=real_dtype,
902
+ complex_dtype=complex_dtype,
903
+ fftargs=kwargs.get("fftargs", {}),
904
+ temp_real=arr,
905
+ temp_fft=ft_temp,
906
+ )
907
+ eps = backend.eps(real_dtype)
908
+ filter_template = backend.size(template_filter) != 0
909
+ template_filter_func = conditional_execute(backend.multiply, filter_template)
910
+
911
+ axis = tuple(range(arr.ndim))
912
+ fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
913
+ fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
914
+
915
+ for index in range(rotations.shape[0]):
916
+ rotation = rotations[index]
917
+ backend.fill(arr, 0)
918
+ backend.fill(temp, 0)
919
+ backend.rotate_array(
920
+ arr=template,
921
+ arr_mask=template_mask,
922
+ rotation_matrix=rotation,
923
+ out=arr,
924
+ out_mask=temp,
925
+ use_geometric_center=False,
926
+ order=interpolation_order,
927
+ )
928
+ n_observations = backend.sum(temp)
929
+
930
+ rfftn(temp, ft_temp)
931
+
932
+ backend.multiply(ft_target, ft_temp, out=ft_denom)
933
+ irfftn(ft_denom, temp)
934
+ backend.divide(temp, n_observations, out=temp)
935
+ backend.square(temp, out=temp)
936
+
937
+ backend.multiply(ft_target2, ft_temp, out=ft_denom)
938
+ irfftn(ft_denom, temp2)
939
+ backend.divide(temp2, n_observations, out=temp2)
940
+
941
+ backend.subtract(temp2, temp, out=temp)
942
+ backend.maximum(temp, 0.0, out=temp)
943
+ backend.sqrt(temp, out=temp)
944
+ backend.multiply(temp, n_observations, out=temp)
945
+
946
+ rfftn(arr, ft_temp)
947
+ template_filter_func(ft_temp, template_filter, out=ft_temp)
948
+ backend.multiply(ft_target, ft_temp, out=ft_temp)
949
+ irfftn(ft_temp, arr)
950
+
951
+ tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
952
+ nonzero_indices = temp > tol
953
+ backend.fill(temp2, 0)
954
+ temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
955
+
956
+ convolution_mode = kwargs.get("convolution_mode", "full")
957
+
958
+ if fourier_shift_scores:
959
+ temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
960
+
961
+ score = apply_convolution_mode(
962
+ temp2, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
963
+ )
964
+
965
+ if callback_class is not None:
966
+ callback(
967
+ score,
968
+ rotation_matrix=rotation,
969
+ rotation_index=index,
970
+ **callback_class_args,
971
+ )
972
+
973
+ return callback
974
+
975
+
976
+ def mcc_scoring(
977
+ template: Tuple[type, Tuple[int], type],
978
+ template_mask: Tuple[type, Tuple[int], type],
979
+ ft_target: Tuple[type, Tuple[int], type],
980
+ ft_target2: Tuple[type, Tuple[int], type],
981
+ ft_target_mask: Tuple[type, Tuple[int], type],
982
+ template_filter: Tuple[type, Tuple[int], type],
983
+ targetshape: Tuple[int],
984
+ templateshape: Tuple[int],
985
+ fast_shape: Tuple[int],
986
+ fast_ft_shape: Tuple[int],
987
+ rotations: NDArray,
988
+ real_dtype: type,
989
+ complex_dtype: type,
990
+ callback_class: CallbackClass,
991
+ callback_class_args: type,
992
+ interpolation_order: int,
993
+ overlap_ratio: float = 0.3,
994
+ **kwargs,
995
+ ) -> CallbackClass:
996
+ """
997
+ Computes a cross-correlation score with masks for both template and target.
998
+
999
+ .. math::
1000
+
1001
+ \\frac{
1002
+ CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
1003
+ }{
1004
+ \\sqrt{
1005
+ (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
1006
+ (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
1007
+ }
1008
+ }
1009
+
1010
+ Where:
1011
+
1012
+ .. math::
1013
+
1014
+ CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1015
+
1016
+
1017
+ References
1018
+ ----------
1019
+ .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
1020
+ .. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
1021
+
1022
+ See Also
1023
+ --------
1024
+ :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
1025
+ """
1026
+ template_buffer, template_shape, template_dtype = template
1027
+ ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
1028
+ ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
1029
+ template_mask_buffer, _, _ = template
1030
+ ft_target_mask_buffer, _, _ = ft_target
1031
+ filter_buffer, filter_shape, filter_dtype = template_filter
1032
+
1033
+ if callback_class is not None and isinstance(callback_class, type):
1034
+ callback = callback_class(**callback_class_args)
1035
+ elif not isinstance(callback_class, type):
1036
+ callback = callback_class
1037
+
1038
+ # Retrieve objects from shared memory
1039
+ template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
1040
+ target_ft = backend.sharedarr_to_arr(
1041
+ ft_target_shape, ft_target_dtype, ft_target_buffer
1042
+ )
1043
+ target_ft2 = backend.sharedarr_to_arr(
1044
+ ft_target_shape, ft_target_dtype, ft_target2_buffer
1045
+ )
1046
+ template_mask = backend.sharedarr_to_arr(
1047
+ template_shape, template_dtype, template_mask_buffer
1048
+ )
1049
+ target_mask_ft = backend.sharedarr_to_arr(
1050
+ ft_target_shape, ft_target_dtype, ft_target_mask_buffer
1051
+ )
1052
+ template_filter = backend.sharedarr_to_arr(
1053
+ filter_shape, filter_dtype, filter_buffer
1054
+ )
1055
+
1056
+ axes = tuple(range(template.ndim))
1057
+ eps = backend.eps(real_dtype)
1058
+
1059
+ # Allocate score and process specific arrays
1060
+ template_rot = backend.preallocate_array(fast_shape, real_dtype)
1061
+ mask_overlap = backend.preallocate_array(fast_shape, real_dtype)
1062
+ numerator = backend.preallocate_array(fast_shape, real_dtype)
1063
+
1064
+ temp = backend.preallocate_array(fast_shape, real_dtype)
1065
+ temp2 = backend.preallocate_array(fast_shape, real_dtype)
1066
+ temp3 = backend.preallocate_array(fast_shape, real_dtype)
1067
+ temp_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
1068
+
1069
+ rfftn, irfftn = backend.build_fft(
1070
+ fast_shape=fast_shape,
1071
+ fast_ft_shape=fast_ft_shape,
1072
+ real_dtype=real_dtype,
1073
+ complex_dtype=complex_dtype,
1074
+ fftargs=kwargs.get("fftargs", {}),
1075
+ temp_real=numerator,
1076
+ temp_fft=temp_ft,
1077
+ )
1078
+
1079
+ filter_template = backend.size(template_filter) != 0
1080
+ template_filter_func = conditional_execute(backend.multiply, filter_template)
1081
+
1082
+ axis = tuple(range(template.ndim))
1083
+ fourier_shift = callback_class_args.get(
1084
+ "fourier_shift", backend.zeros(template.ndim)
1085
+ )
1086
+ fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
1087
+
1088
+ # Calculate scores across all rotations
1089
+ for index in range(rotations.shape[0]):
1090
+ rotation = rotations[index]
1091
+ backend.fill(template_rot, 0)
1092
+ backend.fill(temp, 0)
1093
+
1094
+ backend.rotate_array(
1095
+ arr=template,
1096
+ arr_mask=template_mask,
1097
+ rotation_matrix=rotation,
1098
+ out=template_rot,
1099
+ out_mask=temp,
1100
+ use_geometric_center=False,
1101
+ order=interpolation_order,
1102
+ )
1103
+
1104
+ backend.multiply(template_rot, temp > 0, out=template_rot)
1105
+
1106
+ # template_rot_ft
1107
+ rfftn(template_rot, temp_ft)
1108
+ template_filter_func(temp_ft, template_filter, out=temp_ft)
1109
+ irfftn(target_mask_ft * temp_ft, temp2)
1110
+ irfftn(target_ft * temp_ft, numerator)
1111
+
1112
+ # temp template_mask_rot | temp_ft template_mask_rot_ft
1113
+ # Calculate overlap of masks at every point in the convolution.
1114
+ # Locations with high overlap should not be taken into account.
1115
+ rfftn(temp, temp_ft)
1116
+ irfftn(temp_ft * target_mask_ft, mask_overlap)
1117
+ mask_overlap[:] = np.round(mask_overlap)
1118
+ mask_overlap[:] = np.maximum(mask_overlap, eps)
1119
+ irfftn(temp_ft * target_ft, temp)
1120
+
1121
+ backend.subtract(
1122
+ numerator,
1123
+ backend.divide(backend.multiply(temp, temp2), mask_overlap),
1124
+ out=numerator,
1125
+ )
1126
+
1127
+ # temp_3 = fixed_denom
1128
+ backend.multiply(temp_ft, target_ft2, out=temp_ft)
1129
+ irfftn(temp_ft, temp3)
1130
+ backend.subtract(
1131
+ temp3, backend.divide(backend.square(temp), mask_overlap), out=temp3
1132
+ )
1133
+ backend.maximum(temp3, 0.0, out=temp3)
1134
+
1135
+ # temp = moving_denom
1136
+ rfftn(backend.square(template_rot), temp_ft)
1137
+ backend.multiply(target_mask_ft, temp_ft, out=temp_ft)
1138
+ irfftn(temp_ft, temp)
1139
+
1140
+ backend.subtract(
1141
+ temp, backend.divide(backend.square(temp2), mask_overlap), out=temp
1142
+ )
1143
+ backend.maximum(temp, 0.0, out=temp)
1144
+
1145
+ # temp_2 = denom
1146
+ backend.multiply(temp3, temp, out=temp)
1147
+ backend.sqrt(temp, temp2)
1148
+
1149
+ # Pixels where `denom` is very small will introduce large
1150
+ # numbers after division. To get around this problem,
1151
+ # we zero-out problematic pixels.
1152
+ tol = 1e3 * eps * backend.max(backend.abs(temp2), axis=axes, keepdims=True)
1153
+ nonzero_indices = temp2 > tol
1154
+
1155
+ backend.fill(temp, 0)
1156
+ temp[nonzero_indices] = numerator[nonzero_indices] / temp2[nonzero_indices]
1157
+ backend.clip(temp, a_min=-1, a_max=1, out=temp)
1158
+
1159
+ # Apply overlap ratio threshold
1160
+ number_px_threshold = overlap_ratio * backend.max(
1161
+ mask_overlap, axis=axes, keepdims=True
1162
+ )
1163
+ temp[mask_overlap < number_px_threshold] = 0.0
1164
+ convolution_mode = kwargs.get("convolution_mode", "full")
1165
+
1166
+ if fourier_shift_scores:
1167
+ temp = backend.roll(temp, shift=fourier_shift, axis=axis)
1168
+
1169
+ score = apply_convolution_mode(
1170
+ temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
1171
+ )
1172
+ if callback_class is not None:
1173
+ callback(
1174
+ score,
1175
+ rotation_matrix=rotation,
1176
+ rotation_index=index,
1177
+ **callback_class_args,
1178
+ )
1179
+
1180
+ return callback
1181
+
1182
+
1183
+ def device_memory_handler(func: Callable):
1184
+ """Decorator function providing SharedMemory Handler."""
1185
+
1186
+ @wraps(func)
1187
+ def inner_function(*args, **kwargs):
1188
+ return_value = None
1189
+ last_type, last_value, last_traceback = sys.exc_info()
1190
+ try:
1191
+ with SharedMemoryManager() as smh:
1192
+ with backend.set_device(kwargs.get("gpu_index", 0)):
1193
+ return_value = func(shared_memory_handler=smh, *args, **kwargs)
1194
+ except Exception as e:
1195
+ print(e)
1196
+ last_type, last_value, last_traceback = sys.exc_info()
1197
+ finally:
1198
+ handle_traceback(last_type, last_value, last_traceback)
1199
+ return return_value
1200
+
1201
+ return inner_function
1202
+
1203
+
1204
+ @device_memory_handler
1205
+ def scan(
1206
+ matching_data: MatchingData,
1207
+ matching_setup: Callable,
1208
+ matching_score: Callable,
1209
+ n_jobs: int = 4,
1210
+ callback_class: CallbackClass = None,
1211
+ callback_class_args: Dict = {},
1212
+ fftargs: Dict = {},
1213
+ pad_fourier: bool = True,
1214
+ interpolation_order: int = 3,
1215
+ jobs_per_callback_class: int = 8,
1216
+ **kwargs,
1217
+ ) -> Tuple:
1218
+ """
1219
+ Perform template matching between target and template and sample
1220
+ different rotations of template.
1221
+
1222
+ Parameters
1223
+ ----------
1224
+ matching_data : MatchingData
1225
+ Template matching data.
1226
+ matching_setup : Callable
1227
+ Function pointer to setup function.
1228
+ matching_score : Callable
1229
+ Function pointer to scoring function.
1230
+ n_jobs : int, optional
1231
+ Number of parallel jobs. Default is 4.
1232
+ callback_class : type, optional
1233
+ Analyzer class pointer to operate on computed scores.
1234
+ callback_class_args : dict, optional
1235
+ Arguments passed to the callback_class. Default is an empty dictionary.
1236
+ fftargs : dict, optional
1237
+ Arguments for the FFT operations. Default is an empty dictionary.
1238
+ pad_fourier: bool, optional
1239
+ Whether to pad target and template to the full convolution shape.
1240
+ interpolation_order : int, optional
1241
+ Order of spline interpolation for rotations.
1242
+ jobs_per_callback_class : int, optional
1243
+ How many jobs should be processed by a single callback_class instance,
1244
+ if ones is provided.
1245
+ **kwargs : various
1246
+ Additional arguments.
1247
+
1248
+ Returns
1249
+ -------
1250
+ Tuple
1251
+ The merged results from callback_class if provided otherwise None.
1252
+ """
1253
+ matching_data.to_backend()
1254
+ fourier_pad = matching_data._templateshape
1255
+ fourier_shift = backend.zeros(len(fourier_pad))
1256
+ if not pad_fourier:
1257
+ fourier_pad = backend.full(shape=fourier_shift.shape, fill_value=1, dtype=int)
1258
+
1259
+ convolution_shape, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
1260
+ matching_data._target.shape, fourier_pad
1261
+ )
1262
+ if not pad_fourier:
1263
+ fourier_shift = 1 - backend.astype(
1264
+ backend.divide(matching_data._templateshape, 2), int
1265
+ )
1266
+ fourier_shift -= backend.mod(matching_data._templateshape, 2)
1267
+ fourier_shift = backend.flip(fourier_shift, axis=(0,))
1268
+ shape_diff = backend.subtract(fast_shape, convolution_shape)
1269
+ shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
1270
+ backend.add(fourier_shift, shape_diff, out=fourier_shift)
1271
+
1272
+ callback_class_args["fourier_shift"] = fourier_shift
1273
+ rfftn, irfftn = backend.build_fft(
1274
+ fast_shape=fast_shape,
1275
+ fast_ft_shape=fast_ft_shape,
1276
+ real_dtype=matching_data._default_dtype,
1277
+ complex_dtype=matching_data._complex_dtype,
1278
+ fftargs=fftargs,
1279
+ )
1280
+ setup = matching_setup(
1281
+ rfftn=rfftn,
1282
+ irfftn=irfftn,
1283
+ template=matching_data.template,
1284
+ template_mask=matching_data.template_mask,
1285
+ target=matching_data.target,
1286
+ target_mask=matching_data.target_mask,
1287
+ fast_shape=fast_shape,
1288
+ fast_ft_shape=fast_ft_shape,
1289
+ real_dtype=matching_data._default_dtype,
1290
+ complex_dtype=matching_data._complex_dtype,
1291
+ callback_class=callback_class,
1292
+ callback_class_args=callback_class_args,
1293
+ **kwargs,
1294
+ )
1295
+ rfftn, irfftn = None, None
1296
+
1297
+ template_filter, preprocessor = None, Preprocessor()
1298
+ for method, parameters in matching_data.template_filter.items():
1299
+ parameters["shape"] = fast_shape
1300
+ parameters["omit_negative_frequencies"] = True
1301
+ out = preprocessor.apply_method(method=method, parameters=parameters)
1302
+ if template_filter is None:
1303
+ template_filter = out
1304
+ np.multiply(template_filter, out, out=template_filter)
1305
+
1306
+ if template_filter is None:
1307
+ template_filter = backend.full(
1308
+ shape=(1,), fill_value=1, dtype=backend._default_dtype
1309
+ )
1310
+ else:
1311
+ template_filter = backend.to_backend_array(template_filter)
1312
+
1313
+ template_filter = backend.astype(template_filter, backend._default_dtype)
1314
+ template_filter_buffer = backend.arr_to_sharedarr(
1315
+ arr=template_filter,
1316
+ shared_memory_handler=kwargs.get("shared_memory_handler", None),
1317
+ )
1318
+ setup["template_filter"] = (
1319
+ template_filter_buffer,
1320
+ template_filter.shape,
1321
+ template_filter.dtype,
1322
+ )
1323
+
1324
+ callback_class_args["translation_offset"] = backend.astype(
1325
+ matching_data._translation_offset, int
1326
+ )
1327
+ callback_class_args["thread_safe"] = n_jobs > 1
1328
+ callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)
1329
+
1330
+ n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
1331
+ callback_class = setup.pop("callback_class", callback_class)
1332
+ callback_class_args = setup.pop("callback_class_args", callback_class_args)
1333
+ callback_classes = [callback_class for _ in range(n_callback_classes)]
1334
+ if callback_class == MaxScoreOverRotations:
1335
+ score_space_shape = backend.subtract(
1336
+ matching_data.target.shape,
1337
+ matching_data._target_pad,
1338
+ )
1339
+ callback_classes = [
1340
+ class_name(
1341
+ score_space_shape=score_space_shape,
1342
+ score_space_dtype=matching_data._default_dtype,
1343
+ shared_memory_handler=kwargs.get("shared_memory_handler", None),
1344
+ rotation_space_dtype=backend._default_dtype_int,
1345
+ **callback_class_args,
1346
+ )
1347
+ for class_name in callback_classes
1348
+ ]
1349
+
1350
+ matching_data._target, matching_data._template = None, None
1351
+ matching_data._target_mask, matching_data._template_mask = None, None
1352
+
1353
+ setup["fftargs"] = fftargs.copy()
1354
+ convolution_mode = "same"
1355
+ if backend.sum(matching_data._target_pad) > 0:
1356
+ convolution_mode = "valid"
1357
+ setup["convolution_mode"] = convolution_mode
1358
+ setup["interpolation_order"] = interpolation_order
1359
+ rotation_list = matching_data._split_rotations_on_jobs(n_jobs)
1360
+
1361
+ backend.free_cache()
1362
+
1363
+ def _run_scoring(backend_name, backend_args, rotations, **kwargs):
1364
+ from tme.backends import backend
1365
+
1366
+ backend.change_backend(backend_name, **backend_args)
1367
+ return matching_score(rotations=rotations, **kwargs)
1368
+
1369
+ callbacks = Parallel(n_jobs=n_jobs)(
1370
+ delayed(_run_scoring)(
1371
+ backend_name=backend._backend_name,
1372
+ backend_args=backend._backend_args,
1373
+ rotations=rotation,
1374
+ callback_class=callback_classes[index % n_callback_classes],
1375
+ callback_class_args=callback_class_args,
1376
+ **setup,
1377
+ )
1378
+ for index, rotation in enumerate(rotation_list)
1379
+ )
1380
+
1381
+ callbacks = [
1382
+ tuple(callback)
1383
+ for callback in callbacks[0:n_callback_classes]
1384
+ if callback is not None
1385
+ ]
1386
+ backend.free_cache()
1387
+
1388
+ merged_callback = None
1389
+ if callback_class is not None:
1390
+ merged_callback = callback_class.merge(
1391
+ callbacks,
1392
+ **callback_class_args,
1393
+ score_indices=matching_data.indices,
1394
+ inner_merge=True,
1395
+ )
1396
+
1397
+ return merged_callback
1398
+
1399
+
1400
+ def scan_subsets(
1401
+ matching_data: MatchingData,
1402
+ matching_score: Callable,
1403
+ matching_setup: Callable,
1404
+ callback_class: CallbackClass = None,
1405
+ callback_class_args: Dict = {},
1406
+ job_schedule: Tuple[int] = (1, 1),
1407
+ target_splits: Dict = {},
1408
+ template_splits: Dict = {},
1409
+ pad_target_edges: bool = False,
1410
+ pad_fourier: bool = True,
1411
+ interpolation_order: int = 3,
1412
+ jobs_per_callback_class: int = 8,
1413
+ **kwargs,
1414
+ ) -> Tuple:
1415
+ """
1416
+ Wrapper around :py:meth:`scan` that supports template matching on splits
1417
+ of template and target.
1418
+
1419
+ Parameters
1420
+ ----------
1421
+ matching_data : MatchingData
1422
+ Template matching data.
1423
+ matching_func : type
1424
+ Function pointer to setup function.
1425
+ matching_score : type
1426
+ Function pointer to scoring function.
1427
+ callback_class : type, optional
1428
+ Analyzer class pointer to operate on computed scores.
1429
+ callback_class_args : dict, optional
1430
+ Arguments passed to the callback_class. Default is an empty dictionary.
1431
+ job_schedule : tuple of int, optional
1432
+ Schedule of jobs. Default is (1, 1).
1433
+ target_splits : dict, optional
1434
+ Splits for target. Default is an empty dictionary, i.e. no splits
1435
+ template_splits : dict, optional
1436
+ Splits for template. Default is an empty dictionary, i.e. no splits.
1437
+ pad_target_edges : bool, optional
1438
+ Whether to pad the target boundaries by half the template shape
1439
+ along each axis.
1440
+ pad_fourier: bool, optional
1441
+ Whether to pad target and template to the full convolution shape.
1442
+ interpolation_order : int, optional
1443
+ Order of spline interpolation for rotations.
1444
+ jobs_per_callback_class : int, optional
1445
+ How many jobs should be processed by a single callback_class instance,
1446
+ if ones is provided.
1447
+ **kwargs : various
1448
+ Additional arguments.
1449
+
1450
+ Notes
1451
+ -----
1452
+ Objects in matching_data might be destroyed during computation.
1453
+
1454
+ Returns
1455
+ -------
1456
+ Tuple
1457
+ The merged results from callback_class if provided otherwise None.
1458
+ """
1459
+ target_splits = split_numpy_array_slices(
1460
+ matching_data.target.shape, splits=target_splits
1461
+ )
1462
+ template_splits = split_numpy_array_slices(
1463
+ matching_data._templateshape, splits=template_splits
1464
+ )
1465
+
1466
+ target_pad = np.zeros(len(matching_data.target.shape), dtype=int)
1467
+ if pad_target_edges:
1468
+ target_pad = np.subtract(
1469
+ matching_data._templateshape, np.mod(matching_data._templateshape, 2)
1470
+ )
1471
+ outer_jobs, inner_jobs = job_schedule
1472
+ results = Parallel(n_jobs=outer_jobs)(
1473
+ delayed(_run_inner)(
1474
+ backend_name=backend._backend_name,
1475
+ backend_args=backend._backend_args,
1476
+ matching_data=matching_data.subset_by_slice(
1477
+ target_slice=target_split,
1478
+ target_pad=target_pad,
1479
+ template_slice=template_split,
1480
+ ),
1481
+ matching_score=matching_score,
1482
+ matching_setup=matching_setup,
1483
+ n_jobs=inner_jobs,
1484
+ callback_class=callback_class,
1485
+ callback_class_args=callback_class_args,
1486
+ interpolation_order=interpolation_order,
1487
+ pad_fourier=pad_fourier,
1488
+ gpu_index=index % outer_jobs,
1489
+ **kwargs,
1490
+ )
1491
+ for index, (target_split, template_split) in enumerate(
1492
+ product(target_splits, template_splits)
1493
+ )
1494
+ )
1495
+
1496
+ matching_data._target, matching_data._template = None, None
1497
+ matching_data._target_mask, matching_data._template_mask = None, None
1498
+
1499
+ if callback_class is not None:
1500
+ candidates = callback_class.merge(
1501
+ results, **callback_class_args, inner_merge=False
1502
+ )
1503
+ return candidates
1504
+
1505
+
1506
+ MATCHING_EXHAUSTIVE_REGISTER = {
1507
+ "CC": (cc_setup, corr_scoring),
1508
+ "LCC": (lcc_setup, corr_scoring),
1509
+ "CORR": (corr_setup, corr_scoring),
1510
+ "CAM": (cam_setup, corr_scoring),
1511
+ "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1512
+ "FLC": (flc_setup, flc_scoring),
1513
+ "MCC": (mcc_setup, mcc_scoring),
1514
+ }
1515
+
1516
+
1517
+ def register_matching_exhaustive(
1518
+ matching: str,
1519
+ matching_setup: Callable,
1520
+ matching_scoring: Callable,
1521
+ memory_class: MatchingMemoryUsage,
1522
+ ) -> None:
1523
+ """
1524
+ Registers a new matching scheme.
1525
+
1526
+ Parameters
1527
+ ----------
1528
+ matching : str
1529
+ Name of the matching method.
1530
+ matching_setup : Callable
1531
+ The setup function associated with the name.
1532
+ matching_scoring : Callable
1533
+ The scoring function associated with the name.
1534
+ memory_class : MatchingMemoryUsage
1535
+ The custom memory estimation class extending
1536
+ :py:class:`tme.matching_memory.MatchingMemoryUsage`.
1537
+
1538
+ Raises
1539
+ ------
1540
+ ValueError
1541
+ If a function with the name ``matching`` already exists in the registry.
1542
+ ValueError
1543
+ If ``memory_class`` is not a subclass of
1544
+ :py:class:`tme.matching_memory.MatchingMemoryUsage`.
1545
+ """
1546
+
1547
+ if matching in MATCHING_EXHAUSTIVE_REGISTER:
1548
+ raise ValueError(f"A method with name '{matching}' is already registered.")
1549
+ if not issubclass(memory_class, MatchingMemoryUsage):
1550
+ raise ValueError(f"{memory_class} is not a subclass of {MatchingMemoryUsage}.")
1551
+
1552
+ MATCHING_EXHAUSTIVE_REGISTER[matching] = (matching_setup, matching_scoring)
1553
+ MATCHING_MEMORY_REGISTRY[matching] = memory_class