pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tme/matching_scores.py ADDED
@@ -0,0 +1,1183 @@
1
+ """ Implements a range of cross-correlation coefficients.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import warnings
9
+ from typing import Callable, Tuple, Dict
10
+
11
+ import numpy as np
12
+ from scipy.ndimage import laplace
13
+
14
+ from .backends import backend as be
15
+ from .types import CallbackClass, BackendArray, shm_type
16
+ from .matching_utils import (
17
+ conditional_execute,
18
+ identity,
19
+ normalize_template,
20
+ _normalize_template_overflow_safe,
21
+ )
22
+
23
+
24
+ def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
25
+ """
26
+ Determine whether ``shape1`` is equal to ``shape2``.
27
+
28
+ Parameters
29
+ ----------
30
+ shape1, shape2 : tuple of ints
31
+ Shapes to compare.
32
+
33
+ Returns
34
+ -------
35
+ Bool
36
+ ``shape1`` is equal to ``shape2``.
37
+ """
38
+ if len(shape1) != len(shape2):
39
+ return False
40
+ return shape1 == shape2
41
+
42
+
43
def _create_filter_func(
    fwd_shape: Tuple[int],
    inv_shape: Tuple[int],
    arr_shape: Tuple[int],
    arr_filter: BackendArray,
    arr_ft_shape: Tuple[int],
    inv_output_shape: Tuple[int],
    real_dtype: type,
    cmpl_dtype: type,
    fwd_axes=None,
    inv_axes=None,
    rfftn: Callable = None,
    irfftn: Callable = None,
) -> Callable:
    """
    Configure template filtering function for Fourier transforms.

    Conceptually we distinguish between three cases. The base case
    is that both template and the corresponding filter have the same
    shape. Padding is used when the template filter is larger than
    the template, for instance to better resolve Fourier filters. Finally
    this function also handles the case when a filter is supposed to be
    broadcasted over the template batch dimension.

    Parameters
    ----------
    fwd_shape : tuple of ints
        Input shape the provided ``rfftn`` was built for.
    inv_shape : tuple of ints
        Input shape the provided ``irfftn`` was built for.
    arr_shape : tuple of ints
        Shape of the array to be filtered.
    arr_filter : BackendArray
        Precomputed filter to apply in the frequency domain. A filter with a
        single element disables filtering.
    arr_ft_shape : tuple of ints
        Shape of the Fourier transform of the array.
    inv_output_shape : tuple of ints
        Output shape of the inverse transform. It is replaced by ``arr_shape``
        when ``arr_ft_shape`` and the filter shape cannot be broadcast.
    real_dtype, cmpl_dtype : type
        Real and complex dtypes forwarded to ``be.build_fft``.
    fwd_axes, inv_axes : optional
        Transform axes forwarded to ``be.build_fft``.
    rfftn : Callable, optional
        Forward Fourier transform. It is reused only if its shapes match,
        otherwise a new one is built via ``be.build_fft``.
    irfftn : Callable, optional
        Inverse Fourier transform, reused under the same condition as ``rfftn``.

    Returns
    -------
    Callable
        Filter function with parameters template, ft_temp and template_filter
        that returns the filtered template.
    """
    # A single-element filter means "no filtering": hand back a pass-through.
    if be.size(arr_filter) == 1:
        return conditional_execute(identity, identity, False)

    filter_shape = tuple(int(x) for x in arr_filter.shape)
    try:
        product_ft_shape = np.broadcast_shapes(arr_ft_shape, filter_shape)
    except ValueError:
        # Not broadcastable: transform on the filter's own Fourier shape and
        # produce an array of ``arr_shape``.
        product_ft_shape, inv_output_shape = filter_shape, arr_shape

    rfft_valid = _shape_match(arr_shape, fwd_shape)
    rfft_valid = rfft_valid and _shape_match(product_ft_shape, inv_shape)
    rfft_valid = rfft_valid and rfftn is not None and irfftn is not None

    # FFTs were not provided, or were built for the wrong shapes: build new ones
    if not rfft_valid:
        _fwd_shape = arr_shape
        # Keep the caller's forward shape when arr_shape exceeds
        # product_ft_shape in every dimension.
        if all(x > y for x, y in zip(arr_shape, product_ft_shape)):
            _fwd_shape = fwd_shape

        rfftn, irfftn = be.build_fft(
            fwd_shape=_fwd_shape,
            inv_shape=product_ft_shape,
            real_dtype=real_dtype,
            cmpl_dtype=cmpl_dtype,
            inv_output_shape=inv_output_shape,
            fwd_axes=fwd_axes,
            inv_axes=inv_axes,
        )

    # Default case, all shapes are correctly matched: forward transform into
    # ft_temp, multiply by the filter in place, inverse transform with
    # ``template`` passed as the output argument (presumably written in place).
    def _apply_filter(template, ft_temp, template_filter):
        ft_temp = rfftn(template, ft_temp)
        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
        return irfftn(ft_temp, template)

    if not _shape_match(arr_ft_shape, filter_shape):
        # Leading region of the (padded) array that has shape ``arr_shape``.
        real_subset = tuple(slice(0, x) for x in arr_shape)
        # Scratch buffers shared by the closures below.
        _template = be.zeros(arr_shape, be._float_dtype)
        _ft_temp = be.zeros(product_ft_shape, be._complex_dtype)

        # Arr is padded, filter is not: filter only the leading arr_shape
        # region via scratch buffers and write the result back into template.
        def _apply_filter_subset(template, ft_temp, template_filter):
            # TODO: Benchmark this
            _template[:] = template[real_subset]
            template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
            return template

        # Filter application requires a broadcasting operation. The result is
        # produced in the scratch buffer ``_template`` (not the input array),
        # which is what gets returned.
        def _apply_filter_broadcast(template, ft_temp, template_filter):
            _ft_prod = rfftn(template, _ft_temp2)
            _ft_res = be.multiply(_ft_prod, template_filter, out=_ft_temp)
            return irfftn(_ft_res, _template)

        # The filter is larger than the array's transform along a singleton axis.
        if any(x > y and y == 1 for x, y in zip(filter_shape, arr_ft_shape)):
            # ``_template`` is rebound here before the closure is returned, so
            # _apply_filter_broadcast sees this inv_output_shape buffer;
            # ``_ft_temp2`` only exists on this path.
            _template = be.zeros(inv_output_shape, be._float_dtype)
            _ft_temp2 = be.zeros((1, *product_ft_shape[1:]), be._complex_dtype)
            return _apply_filter_broadcast

        return _apply_filter_subset

    return _apply_filter
150
+
151
+
152
def cc_setup(
    matching_data: type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shm_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for computing the unnormalized cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    The target is padded to ``fast_shape`` via ``be.topleft_pad`` (template
    batch dimensions kept at size one) and Fourier transformed along the data
    axes once. Since the score is unnormalized, ``numerator`` is a single zero
    and ``inv_denominator`` a single one.

    Returns
    -------
    Dict
        Keys ``fast_shape``, ``fast_ft_shape``, ``template``, ``ft_target``,
        ``inv_denominator`` and ``numerator``; arrays are converted with
        ``be.to_sharedarr`` using ``shm_handler``.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """
    pad_shape = matching_data._batch_shape(fast_shape, matching_data._template_batch)
    padded_target = be.topleft_pad(matching_data.target, pad_shape)
    fft_axes = matching_data._batch_axis(matching_data._batch_mask)

    shared_template = be.to_sharedarr(matching_data.template, shm_handler)
    shared_ft_target = be.to_sharedarr(
        be.rfftn(padded_target, axes=fft_axes), shm_handler
    )
    unit_denominator = be.zeros(1, be._float_dtype) + 1
    shared_denominator = be.to_sharedarr(unit_denominator, shm_handler)
    zero_numerator = be.zeros(1, be._float_dtype)
    shared_numerator = be.to_sharedarr(zero_numerator, shm_handler)

    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": shared_template,
        "ft_target": shared_ft_target,
        "inv_denominator": shared_denominator,
        "numerator": shared_numerator,
    }
189
+
190
+
191
def lcc_setup(matching_data, **kwargs) -> Dict:
    """
    Setup function for computing the laplace cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)

    The Laplacian (``scipy.ndimage.laplace`` with ``mode="wrap"``) is applied
    separately to every subset yielded by ``matching_data._batch_iter`` for the
    target and the template. The filtered arrays replace ``matching_data._target``
    and ``matching_data._template``, and the remaining setup is delegated to
    :py:meth:`cc_setup`.

    Parameters
    ----------
    matching_data : MatchingData
        Container holding target and template. Its ``_target`` and
        ``_template`` attributes are replaced by the filtered numpy arrays.
    **kwargs
        Forwarded to :py:meth:`cc_setup`.

    Returns
    -------
    Dict
        Output of :py:meth:`cc_setup` for the Laplacian-filtered data.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """

    def _laplace_per_batch(arr, batch_dims):
        # Mark the array's own batch dimensions with 1 so _batch_iter yields
        # separate subsets (presumably one per batch element) to filter.
        # The mask length must follow ``arr.ndim`` of the array being filtered.
        batch_mask = tuple(1 if i in batch_dims else 0 for i in range(arr.ndim))
        for subset in matching_data._batch_iter(arr.shape, batch_mask):
            arr[subset] = laplace(arr[subset], mode="wrap")
        return arr

    target = be.to_numpy_array(matching_data._target)
    template = be.to_numpy_array(matching_data._template)

    target = _laplace_per_batch(target, matching_data._target_dim)
    template = _laplace_per_batch(template, matching_data._template_dim)

    matching_data._target = target
    matching_data._template = template

    return cc_setup(matching_data=matching_data, **kwargs)
225
+
226
+
227
def corr_setup(
    matching_data,
    template_filter,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shm_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup for computing a normalized cross-correlation between a
    ``target`` (f), a ``template`` (g) given ``template_mask`` (m)

    .. math::

        \\frac{CC(f, g \\cdot m) - \\overline{g} \\cdot CC(f, m)}
        {\\sqrt{(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}}},

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*),

    :math:`\\overline{g}` is the mean of g under m, :math:`\\sigma_{g}` is the
    sum of squared deviations of g from that mean under m, and :math:`N_g` is
    the number of template elements along its non-batch axes.

    The rotation-independent terms are precomputed here. ``numerator`` holds
    :math:`\\overline{g} \\cdot CC(f, m)`. ``inv_denominator`` holds the
    reciprocal of the denominator, set to zero wherever the denominator does
    not exceed machine epsilon. The returned ``template`` is multiplied by m.

    Parameters
    ----------
    matching_data : MatchingData
        Container providing target, template and template mask.
    template_filter : BackendArray
        Only inspected for its size; a filter with more than one element
        triggers a warning because it is not accounted for in the scaling.
    fast_shape : tuple of ints
        Real-space shape used for the Fourier transforms.
    fast_ft_shape : tuple of ints
        Corresponding Fourier-space shape.
    shm_handler : type
        Passed to ``be.to_sharedarr`` for every returned array.
    **kwargs
        Ignored.

    Returns
    -------
    Dict
        Keys ``fast_shape``, ``fast_ft_shape``, ``template``, ``ft_target``,
        ``inv_denominator`` and ``numerator``.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.

    References
    ----------
    .. [1] Lewis P. J. Fast Normalized Cross-Correlation, Industrial Light and Magic.
    """
    template, template_mask = matching_data.template, matching_data.template_mask
    target_pad = be.topleft_pad(
        matching_data.target,
        matching_data._batch_shape(fast_shape, matching_data._template_batch),
    )
    data_axes = matching_data._batch_axis(matching_data._batch_mask)
    data_shape = tuple(fast_shape[i] for i in data_axes)

    ft_window = be.rfftn(be.topleft_pad(template_mask, fast_shape), axes=data_axes)

    # Local sum of the squared target under the mask window.
    ft_target = be.rfftn(be.square(target_pad), axes=data_axes)
    ft_target = be.multiply(ft_target, ft_window)
    denominator = be.irfftn(ft_target, s=data_shape, axes=data_axes)

    # Local sum of the target under the mask window; ft_target now holds the
    # transform of the (unsquared) padded target and is what gets returned.
    ft_target = be.rfftn(target_pad, axes=data_axes)
    ft_window = be.multiply(ft_target, ft_window)
    window_sum = be.irfftn(ft_window, s=data_shape, axes=data_axes)

    # Drop references to large intermediates that are no longer needed.
    target_pad, ft_window = None, None

    # TODO: Factor in template_filter here
    if be.size(template_filter) != 1:
        warnings.warn(
            "CORR scores obtained with template_filter are not correctly scaled. "
            "Please use a different score or consider only relative peak heights."
        )
    # Template statistics over its non-batch axes, one value per batch element.
    axis = matching_data._batch_axis(matching_data._template_batch)
    n_obs = be.sum(
        be.astype(template_mask, be._overflow_safe_dtype), axis=axis, keepdims=True
    )
    template_mean = be.multiply(template, template_mask)
    template_mean = be.sum(template_mean, axis=axis, keepdims=True)
    template_mean = be.divide(template_mean, n_obs)
    template_ssd = be.square(template - template_mean) * template_mask
    template_ssd = be.sum(template_ssd, axis=axis, keepdims=True)

    # N_g: product of the template's non-batch dimensions.
    # NOTE(review): the variance term below divides by this full volume rather
    # than the mask sum ``n_obs`` used for the mean; the two only agree for an
    # all-ones mask -- confirm this is intended.
    template_volume = np.prod(
        tuple(
            int(x)
            for i, x in enumerate(template.shape)
            if matching_data._template_batch[i] == 0
        )
    )
    # In place; if matching_data.template returns the stored array rather than
    # a copy, this also modifies it -- TODO confirm.
    template = be.multiply(template, template_mask, out=template)

    numerator = be.multiply(window_sum, template_mean)
    window_sum = be.square(window_sum, out=window_sum)
    window_sum = be.divide(window_sum, template_volume, out=window_sum)
    denominator = be.subtract(denominator, window_sum, out=denominator)
    denominator = be.multiply(denominator, template_ssd, out=denominator)
    # Clamp negative values (e.g. from round-off) before the square root.
    denominator = be.maximum(denominator, 0, out=denominator)
    denominator = be.sqrt(denominator, out=denominator)

    # Safe reciprocal: entries <= eps are temporarily set to 1 to avoid
    # division by zero, then zeroed again after inversion.
    mask = denominator > be.eps(be._float_dtype)
    denominator = be.multiply(denominator, mask, out=denominator)
    denominator = be.add(denominator, ~mask, out=denominator)
    denominator = be.divide(1, denominator, out=denominator)
    denominator = be.multiply(denominator, mask, out=denominator)

    ret = {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shm_handler),
        "ft_target": be.to_sharedarr(ft_target, shm_handler),
        "inv_denominator": be.to_sharedarr(denominator, shm_handler),
        "numerator": be.to_sharedarr(numerator, shm_handler),
    }

    return ret
327
+
328
+
329
def cam_setup(matching_data, **kwargs) -> Dict:
    """
    Like :py:meth:`corr_setup` but with standardized ``target``, ``template``

    .. math::

        f' = \\frac{f - \\overline{f}}{\\sigma_f}.

    Each array is standardized over the axes ``matching_data._batch_axis``
    returns for its *own* batch mask: ``_template_batch`` for the template and
    ``_target_batch`` for the target. This is the convention
    :py:meth:`corr_setup` uses for the template statistics, so every batch
    element is standardized independently.

    Parameters
    ----------
    matching_data : MatchingData
        Container holding target and template. Its ``_template`` and
        ``_target`` attributes are replaced by their standardized versions.
    **kwargs
        Forwarded to :py:meth:`corr_setup`.

    Returns
    -------
    Dict
        Output of :py:meth:`corr_setup` for the standardized data.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """

    def _standardize(arr, batch_mask):
        # Reduce over the non-batch axes of ``arr`` only, keeping dims so the
        # statistics broadcast back against the array.
        axis = matching_data._batch_axis(batch_mask)
        centered = be.subtract(arr, be.mean(arr, axis=axis, keepdims=True))
        return be.divide(centered, be.std(arr, axis=axis, keepdims=True))

    matching_data._template = _standardize(
        matching_data._template, matching_data._template_batch
    )
    matching_data._target = _standardize(
        matching_data._target, matching_data._target_batch
    )
    return corr_setup(matching_data=matching_data, **kwargs)
354
+
355
+
356
def flc_setup(
    matching_data,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shm_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`flc_scoring`.

    The target is padded to ``fast_shape`` with template batch dimensions kept
    at one. The Fourier transforms of the padded target and of its square are
    then precomputed along the data axes. Template and template mask are passed
    through unchanged; every array is converted with ``be.to_sharedarr``.

    Returns
    -------
    Dict
        Keys ``fast_shape``, ``fast_ft_shape``, ``template``, ``template_mask``,
        ``ft_target`` and ``ft_target2``.
    """
    padded = be.topleft_pad(
        matching_data.target,
        matching_data._batch_shape(fast_shape, matching_data._template_batch),
    )
    fft_axes = matching_data._batch_axis(matching_data._batch_mask)

    target_ft = be.rfftn(padded, axes=fft_axes)
    # Square in place: the unsquared padded target is not needed any more.
    padded = be.square(padded, out=padded)
    target_sq_ft = be.rfftn(padded, axes=fft_axes)

    setup = {"fast_shape": fast_shape, "fast_ft_shape": fast_ft_shape}
    setup["template"] = be.to_sharedarr(matching_data.template, shm_handler)
    setup["template_mask"] = be.to_sharedarr(
        matching_data.template_mask, shm_handler
    )
    setup["ft_target"] = be.to_sharedarr(target_ft, shm_handler)
    setup["ft_target2"] = be.to_sharedarr(target_sq_ft, shm_handler)
    return setup
387
+
388
+
389
def flcSphericalMask_setup(
    matching_data,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shm_handler: type,
    **kwargs,
) -> Dict:
    """
    Like :py:meth:`flc_setup` for rotation invariant masks

    The mask is assumed not to change under rotation, so the mask-dependent
    target statistics are computed once here. The resulting normalization
    term, from ``be.norm_scores`` with a unit numerator, is returned as
    ``inv_denominator``; ``numerator`` is a single zero.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """
    mask = matching_data.template_mask
    template_axes = matching_data._batch_axis(matching_data._template_batch)
    mask_sum = be.sum(
        be.astype(mask, be._overflow_safe_dtype), axis=template_axes, keepdims=True
    )

    padded_target = be.topleft_pad(
        matching_data.target,
        matching_data._batch_shape(fast_shape, matching_data._template_batch),
    )

    # Enable mask broadcasting: pad the mask to fast_shape only along the
    # template's non-batch axes and keep its own extent elsewhere.
    mask_shape = tuple(
        fast_dim if i in template_axes else mask_dim
        for i, (mask_dim, fast_dim) in enumerate(zip(mask.shape, fast_shape))
    )
    padded_mask = be.topleft_pad(
        mask, matching_data._batch_shape(mask_shape, matching_data._target_batch)
    )

    fft_axes = matching_data._batch_axis(matching_data._batch_mask)
    fft_shape = tuple(fast_shape[i] for i in fft_axes)

    ft_buffer = be.zeros(fast_ft_shape, be._complex_dtype)
    mask_ft = be.rfftn(padded_mask, s=fft_shape, axes=fft_axes)

    # Fourier product of the squared target with the padded mask.
    target_ft = be.rfftn(be.square(padded_target), axes=fft_axes)
    ft_buffer = be.multiply(target_ft, mask_ft, out=ft_buffer)
    local_sq_sum = be.irfftn(ft_buffer, s=fft_shape, axes=fft_axes)

    # Fourier product of the target with the padded mask; this target
    # transform is the one handed to the scoring function.
    target_ft = be.rfftn(padded_target, axes=fft_axes)
    ft_buffer = be.multiply(target_ft, mask_ft, out=ft_buffer)
    local_sum = be.irfftn(ft_buffer, s=fft_shape, axes=fft_axes)

    inv_denominator = be.norm_scores(
        1, local_sq_sum, local_sum, mask_sum, be.eps(be._float_dtype), local_sq_sum
    )
    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(matching_data.template, shm_handler),
        "template_mask": be.to_sharedarr(mask, shm_handler),
        "ft_target": be.to_sharedarr(target_ft, shm_handler),
        "inv_denominator": be.to_sharedarr(inv_denominator, shm_handler),
        "numerator": be.to_sharedarr(be.zeros(1, be._float_dtype), shm_handler),
    }
450
+
451
+
452
def mcc_setup(
    matching_data,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shm_handler: Callable,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`mcc_scoring`.

    The target is multiplied by ``target_mask > 0``, with the result written
    into the target array itself. Target and target mask are then padded to
    ``fast_shape`` with template batch dimensions kept at one. Finally the
    Fourier transforms of the masked target, its square and the target mask
    are precomputed along the data axes.

    Returns
    -------
    Dict
        Keys ``fast_shape``, ``fast_ft_shape``, ``template``, ``template_mask``,
        ``ft_target``, ``ft_target2`` and ``ft_target_mask``.
    """
    target, target_mask = matching_data.target, matching_data.target_mask
    # Zero target values outside the target mask, reusing the target buffer.
    target = be.multiply(target, target_mask > 0, out=target)

    padded_shape = matching_data._batch_shape(
        fast_shape, matching_data._template_batch
    )
    padded_target = be.topleft_pad(target, padded_shape)
    padded_mask = be.topleft_pad(target_mask, padded_shape)
    fft_axes = matching_data._batch_axis(matching_data._batch_mask)

    shared = {"fast_shape": fast_shape, "fast_ft_shape": fast_ft_shape}
    shared["template"] = be.to_sharedarr(matching_data.template, shm_handler)
    shared["template_mask"] = be.to_sharedarr(
        matching_data.template_mask, shm_handler
    )
    shared["ft_target"] = be.to_sharedarr(
        be.rfftn(padded_target, axes=fft_axes), shm_handler
    )
    shared["ft_target2"] = be.to_sharedarr(
        be.rfftn(be.square(padded_target), axes=fft_axes), shm_handler
    )
    shared["ft_target_mask"] = be.to_sharedarr(
        be.rfftn(padded_mask, axes=fft_axes), shm_handler
    )
    return shared
488
+
489
+
490
def corr_scoring(
    template: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    inv_denominator: shm_type,
    numerator: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    template_mask: shm_type = None,
) -> CallbackClass:
    """
    Calculates a normalized cross-correlation between a target f and a template g.

    .. math::

        (CC(f,g) - \\text{numerator}) \\cdot \\text{inv\\_denominator},

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    For every rotation the template is rotated, optionally filtered and,
    if ``template_mask`` is given, normalized under that mask. It is then
    correlated with the target, and ``numerator``/``inv_denominator`` are
    applied before the result is passed to ``callback``.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    inv_denominator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Inverse denominator data buffer, its shape and datatype.
    numerator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Numerator data buffer, its shape, and its datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Fourier-space shape, i.e. the input shape of the inverse transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
        Template mask data buffer, its shape and datatype, None by default.
        The mask is not rotated, so it is only meaningful for rotation
        invariant masks such as those prepared by ``flcSphericalMask_setup``.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.
    """
    template = be.from_sharedarr(template)
    ft_target = be.from_sharedarr(ft_target)
    inv_denominator = be.from_sharedarr(inv_denominator)
    numerator = be.from_sharedarr(numerator)
    template_filter = be.from_sharedarr(template_filter)

    # Template normalization is only enabled when a mask is supplied.
    norm_func, norm_template, mask_sum = normalize_template, False, 1
    if template_mask is not None:
        template_mask = be.from_sharedarr(template_mask)
        norm_template, mask_sum = True, be.sum(template_mask)
        # 2-byte masks: sum in the overflow-safe dtype and use the matching
        # normalization routine.
        if be.datatype_bytes(template_mask.dtype) == 2:
            norm_func = _normalize_template_overflow_safe
            mask_sum = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    callback_func = conditional_execute(callback, callback is not None)
    norm_template = conditional_execute(norm_func, norm_template)
    # NOTE(review): unlike the two calls above, ``identity`` is passed as the
    # second argument and the shape check as the third -- confirm this matches
    # the parameter order of matching_utils.conditional_execute.
    norm_numerator = conditional_execute(
        be.subtract, identity, _shape_match(numerator.shape, fast_shape)
    )
    norm_denominator = conditional_execute(
        be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
    )

    # Work buffers reused across all rotations.
    arr = be.zeros(fast_shape, be._float_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)

    _fftargs = {
        "real_dtype": be._float_dtype,
        "cmpl_dtype": be._complex_dtype,
        "inv_output_shape": fast_shape,
        "fwd_axes": None,
        "inv_axes": None,
        "inv_shape": fast_ft_shape,
        "temp_fwd": arr,
    }

    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
    rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
    # temp_fwd is a build_fft-only argument; drop it before reuse below.
    _ = _fftargs.pop("temp_fwd", None)

    template_filter_func = _create_filter_func(
        arr_shape=template.shape,
        arr_ft_shape=fast_ft_shape,
        arr_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
        **_fftargs,
    )

    # Region of the padded buffer covered by the unpadded template.
    unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        arr = be.fill(arr, 0)
        arr, _ = be.rigid_transform(
            arr=template,
            rotation_matrix=rotation,
            out=arr,
            use_geometric_center=True,
            order=interpolation_order,
            cache=False,
        )
        arr = template_filter_func(arr, ft_temp, template_filter)
        # The return value is discarded, so this relies on the normalization
        # writing into the arr[unpadded_slice] view in place.
        # NOTE(review): confirm for backends whose slices are copies.
        norm_template(arr[unpadded_slice], template_mask, mask_sum)

        ft_temp = rfftn(arr, ft_temp)
        ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
        arr = irfftn(ft_temp, arr)

        arr = norm_numerator(arr, numerator, out=arr)
        arr = norm_denominator(arr, inv_denominator, out=arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
619
+
620
+
621
def flc_scoring(
    template: shm_type,
    template_mask: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    template_filter: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
) -> CallbackClass:
    """
    Computes a normalized cross-correlation between ``target`` (f),
    ``template`` (g), and ``template_mask`` (m)

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and :math:`N_m` is the sum of the rotated mask m.

    Template and mask are rotated jointly for every rotation, so the local
    target statistics under the mask are recomputed per rotation.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    fast_shape : tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape : tuple of ints
        Fourier-space shape, i.e. the input shape of the inverse transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` as passed in.

    References
    ----------
    .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    template_mask = be.from_sharedarr(template_mask)
    ft_target = be.from_sharedarr(ft_target)
    ft_target2 = be.from_sharedarr(ft_target2)
    template_filter = be.from_sharedarr(template_filter)

    # Work buffers reused across rotations: arr holds the template and later
    # the score, temp the rotated mask and later the local target sum, temp2
    # the local sum of the squared target.
    arr = be.zeros(fast_shape, float_dtype)
    temp = be.zeros(fast_shape, float_dtype)
    temp2 = be.zeros(fast_shape, float_dtype)
    ft_temp = be.zeros(fast_ft_shape, complex_dtype)
    ft_denom = be.zeros(fast_ft_shape, complex_dtype)

    _fftargs = {
        "real_dtype": be._float_dtype,
        "cmpl_dtype": be._complex_dtype,
        "inv_output_shape": fast_shape,
        "fwd_axes": None,
        "inv_axes": None,
        "inv_shape": fast_ft_shape,
        "temp_fwd": arr,
    }

    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
    rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
    # temp_fwd is a build_fft-only argument; drop it before reuse below.
    _ = _fftargs.pop("temp_fwd", None)

    template_filter_func = _create_filter_func(
        arr_shape=template.shape,
        arr_ft_shape=fast_ft_shape,
        arr_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
        **_fftargs,
    )

    eps = be.eps(float_dtype)
    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        arr = be.fill(arr, 0)
        temp = be.fill(temp, 0)
        # Rotate template and mask together into arr and temp.
        arr, temp = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation,
            out=arr,
            out_mask=temp,
            use_geometric_center=True,
            order=interpolation_order,
            cache=False,
        )

        # N_m: sum of the rotated mask.
        n_obs = be.sum(temp)
        arr = template_filter_func(arr, ft_temp, template_filter)
        arr = normalize_template(arr, temp, n_obs, axis=None)

        # Mask-weighted local sums of the target (temp) and squared target (temp2).
        ft_temp = rfftn(temp, ft_temp)
        ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
        temp = irfftn(ft_denom, temp)
        ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
        temp2 = irfftn(ft_denom, temp2)

        # Correlation of the target with the normalized template.
        ft_temp = rfftn(arr, ft_temp)
        ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
        arr = irfftn(ft_temp, arr)

        # Combine into the final score, written into arr.
        arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
752
+
753
+
754
def mcc_scoring(
    template: shm_type,
    template_mask: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    ft_target_mask: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    overlap_ratio: float = 0.3,
) -> CallbackClass:
    """
    Computes a normalized cross-correlation score between ``target`` (f),
    ``template`` (g), ``template_mask`` (m) and ``target_mask`` (t)

    .. math::

        \\frac{
            CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
        }{
            \\sqrt{
                (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
                (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
            }
        },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Scores are clipped to [-1, 1]. Positions whose denominator is at most
    ``1e3 * eps`` times the largest denominator, as well as positions whose
    mask overlap is below ``overlap_ratio`` times the maximal overlap, are
    set to zero.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    ft_target_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target mask data buffer, its shape and datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    overlap_ratio : float, optional
        Required fractional mask overlap, 0.3 by default.

    Returns
    -------
    CallbackClass
        The ``callback`` that was passed in.

    References
    ----------
    .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
    .. [2] https://scikit-image.org/docs/stable/api/skimage.registration.html
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    target_ft = be.from_sharedarr(ft_target)
    target_ft2 = be.from_sharedarr(ft_target2)
    template_mask = be.from_sharedarr(template_mask)
    target_mask_ft = be.from_sharedarr(ft_target_mask)
    template_filter = be.from_sharedarr(template_filter)

    axes = tuple(range(template.ndim))
    eps = be.eps(float_dtype)

    # Allocate score and process specific arrays
    template_rot = be.zeros(fast_shape, float_dtype)
    mask_overlap = be.zeros(fast_shape, float_dtype)
    numerator = be.zeros(fast_shape, float_dtype)
    temp = be.zeros(fast_shape, float_dtype)
    temp2 = be.zeros(fast_shape, float_dtype)
    temp3 = be.zeros(fast_shape, float_dtype)
    temp_ft = be.zeros(fast_ft_shape, complex_dtype)

    _fftargs = {
        "real_dtype": float_dtype,
        "cmpl_dtype": complex_dtype,
        "inv_output_shape": fast_shape,
        "fwd_axes": None,
        "inv_axes": None,
        "inv_shape": fast_ft_shape,
        "temp_fwd": temp,
    }

    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
    rfftn, irfftn = be.build_fft(temp_inv=temp_ft, **_fftargs)
    _ = _fftargs.pop("temp_fwd", None)

    template_filter_func = _create_filter_func(
        arr_shape=template.shape,
        arr_ft_shape=fast_ft_shape,
        arr_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
        **_fftargs,
    )

    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        template_rot = be.fill(template_rot, 0)
        temp = be.fill(temp, 0)
        # Rebind every result, as flc_scoring does: a backend is not required
        # to write into the provided output buffers.
        template_rot, temp = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation,
            out=template_rot,
            out_mask=temp,
            use_geometric_center=True,
            order=interpolation_order,
            cache=False,
        )

        template_rot = template_filter_func(template_rot, temp_ft, template_filter)
        template_rot = normalize_template(template_rot, temp, be.sum(temp))

        temp_ft = rfftn(template_rot, temp_ft)
        temp2 = irfftn(target_mask_ft * temp_ft, temp2)
        numerator = irfftn(target_ft * temp_ft, numerator)

        # temp template_mask_rot | temp_ft template_mask_rot_ft
        # Calculate overlap of masks at every point in the convolution.
        # Locations with high overlap should not be taken into account.
        temp_ft = rfftn(temp, temp_ft)
        mask_overlap = irfftn(temp_ft * target_mask_ft, mask_overlap)
        mask_overlap = be.maximum(mask_overlap, eps, out=mask_overlap)
        temp = irfftn(temp_ft * target_ft, temp)

        numerator = be.subtract(
            numerator,
            be.divide(be.multiply(temp, temp2), mask_overlap),
            out=numerator,
        )

        # temp_3 = fixed_denom
        temp_ft = be.multiply(temp_ft, target_ft2, out=temp_ft)
        temp3 = irfftn(temp_ft, temp3)
        temp3 = be.subtract(
            temp3, be.divide(be.square(temp), mask_overlap), out=temp3
        )
        temp3 = be.maximum(temp3, 0.0, out=temp3)

        # temp = moving_denom
        temp_ft = rfftn(be.square(template_rot), temp_ft)
        temp_ft = be.multiply(target_mask_ft, temp_ft, out=temp_ft)
        temp = irfftn(temp_ft, temp)

        temp = be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
        temp = be.maximum(temp, 0.0, out=temp)

        # temp_2 = denom
        temp = be.multiply(temp3, temp, out=temp)
        temp2 = be.sqrt(temp, temp2)

        # Pixels where `denom` is very small will introduce large
        # numbers after division. To get around this problem,
        # we zero-out problematic pixels. `<=` (rather than `<`) also catches
        # the case where every denominator is 0 and hence tol is 0, which
        # would otherwise produce 0 / 0 = NaN.
        tol = 1e3 * eps * be.max(be.abs(temp2), axis=axes, keepdims=True)
        invalid = temp2 <= tol

        # Avoid division by ~0, then zero the affected scores explicitly
        # instead of leaving clip(numerator) in place.
        temp2[invalid] = 1
        temp = be.divide(numerator, temp2, out=temp)
        temp = be.clip(temp, a_min=-1, a_max=1, out=temp)
        temp[invalid] = 0.0

        # Apply overlap ratio threshold
        number_px_threshold = overlap_ratio * be.max(
            mask_overlap, axis=axes, keepdims=True
        )
        temp[mask_overlap < number_px_threshold] = 0.0
        callback_func(temp, rotation_matrix=rotation)

    return callback
936
+
937
+
938
+ def _format_slice(shape, squeeze_axis):
939
+ ret = tuple(
940
+ slice(None) if i not in squeeze_axis else 0 for i, _ in enumerate(shape)
941
+ )
942
+ return ret
943
+
944
+
945
+ def _get_batch_dim(target, template):
946
+ target_batch, template_batch = [], []
947
+ for i in range(len(target.shape)):
948
+ if target.shape[i] == 1 and template.shape[i] != 1:
949
+ template_batch.append(i)
950
+ if target.shape[i] != 1 and template.shape[i] == 1:
951
+ target_batch.append(i)
952
+
953
+ return target_batch, template_batch
954
+
955
+
956
def flc_scoring2(
    template: shm_type,
    template_mask: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    template_filter: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
) -> CallbackClass:
    """
    Batch-aware fast local correlation (FLC) scoring.

    Axes on which exactly one of ``ft_target`` and ``template`` has size one
    (see ``_get_batch_dim``) are treated as batch axes. The mask-related
    Fourier transforms are then restricted to the remaining data axes. When
    no such axes exist, all axes are transformed.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype. A dedicated
        Fourier buffer of the broadcast shape is allocated when its size
        is not one.
    fast_shape : tuple of ints
        Real-space shape of the working arrays.
    fast_ft_shape : tuple of ints
        Fourier-space shape of the working arrays.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        Called as ``callback(scores, rotation_matrix=rotation)`` once per
        rotation via ``conditional_execute(callback, callback is not None)``.
    interpolation_order : int
        Spline order for template rotations.

    Returns
    -------
    CallbackClass
        The ``callback`` that was passed in.
    """
    callback_func = conditional_execute(callback, callback is not None)

    # Retrieve objects from shared memory
    template = be.from_sharedarr(template)
    template_mask = be.from_sharedarr(template_mask)
    ft_target = be.from_sharedarr(ft_target)
    ft_target2 = be.from_sharedarr(ft_target2)
    template_filter = be.from_sharedarr(template_filter)

    # Batch axes of the target collapse to size one on the template side
    data_axes = None
    target_batch, template_batch = _get_batch_dim(ft_target, template)
    sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
    sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)

    data_shape = fast_shape
    if len(target_batch) or len(template_batch):
        batch = (*target_batch, *template_batch)
        data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
        data_shape = tuple(fast_shape[i] for i in data_axes)

    arr = be.zeros(fast_shape, be._float_dtype)
    temp = be.zeros(fast_shape, be._float_dtype)
    temp2 = be.zeros(fast_shape, be._float_dtype)
    ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)

    # NOTE(review): without a filter, ft_temp is a view into ft_denom. The
    # products below read ft_temp while writing into ft_denom; that is only
    # correct if be.rfftn returns a fresh array instead of writing into its
    # output argument. Confirm for every backend.
    tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
    if be.size(template_filter) != 1:
        ret_shape = np.broadcast_shapes(
            sqz_cmpl, tuple(int(x) for x in template_filter.shape)
        )
        ft_temp = be.zeros(ret_shape, be._complex_dtype)

    _fftargs = {
        "real_dtype": be._float_dtype,
        "cmpl_dtype": be._complex_dtype,
        "inv_output_shape": fast_shape,
        "fwd_axes": data_axes,
        "inv_axes": data_axes,
        "inv_shape": fast_ft_shape,
        "temp_fwd": arr_sqz if _shape_match(ft_temp.shape, sqz_cmpl) else arr,
    }

    # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
    rfftn, irfftn = be.build_fft(temp_inv=ft_denom, **_fftargs)
    _ = _fftargs.pop("temp_fwd", None)

    template_filter_func = _create_filter_func(
        arr_shape=template.shape,
        arr_ft_shape=sqz_cmpl,
        arr_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
        **_fftargs,
    )

    eps = be.eps(be._float_dtype)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        be.fill(arr, 0)
        be.fill(temp, 0)
        arr_sqz, tmp_sqz = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation,
            out=arr_sqz,
            out_mask=tmp_sqz,
            use_geometric_center=True,
            order=interpolation_order,
            cache=False,
        )
        n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
        arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
        arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=data_axes)

        # Target and squared-target terms under the rotated mask
        ft_temp = be.rfftn(tmp_sqz, ft_temp, axes=data_axes)
        ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
        temp = be.irfftn(ft_denom, temp, axes=data_axes, s=data_shape)
        ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
        temp2 = be.irfftn(ft_denom, temp2, axes=data_axes, s=data_shape)

        # Template spectrum gets its own name. Rebinding ft_temp here would
        # hand the ft_denom buffer to the filter and mask transforms of the
        # next rotation instead of the dedicated scratch buffer.
        ft_template = rfftn(arr_norm, ft_denom)
        ft_denom = be.multiply(ft_target, ft_template, out=ft_denom)
        arr = irfftn(ft_denom, arr)

        # Rebind as flc_scoring does, in case a backend returns a new array
        arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
1057
+
1058
+
1059
def corr_scoring2(
    template: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    inv_denominator: shm_type,
    numerator: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    target_filter: shm_type = None,
    template_mask: shm_type = None,
) -> CallbackClass:
    """
    Batch-aware correlation scoring with precomputed numerator/denominator.

    For every rotation, ``template`` is rotated, filtered and optionally
    normalized under ``template_mask``. It is then multiplied with
    ``ft_target`` in Fourier space and transformed back. Finally the result
    is combined with ``numerator`` (via ``be.subtract``) and
    ``inv_denominator`` (via ``be.multiply``).

    Axes on which exactly one of ``ft_target`` and ``template`` has size one
    (see ``_get_batch_dim``) are treated as batch axes and excluded from the
    Fourier transforms.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype. Its shape may
        refer to the unpadded template.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    inv_denominator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Multiplied into the scores after the numerator step.
    numerator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Subtracted from the back-transformed correlation.
    fast_shape : tuple of ints
        Real-space shape of the working arrays.
    fast_ft_shape : tuple of ints
        Fourier-space shape of the working arrays.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        Called as ``callback(scores, rotation_matrix=rotation)`` once per
        rotation via ``conditional_execute(callback, callback is not None)``.
    interpolation_order : int
        Spline order for template rotations.
    target_filter : shm_type, optional
        Not referenced in this function.
    template_mask : shm_type, optional
        If given, the rotated template's unpadded region is normalized with
        this mask and its sum along the data axes. A mask dtype of two bytes
        selects ``_normalize_template_overflow_safe``. The mask itself is not
        rotated. This presumably relies on a rotation-invariant (spherical)
        mask, as the "batchFLCSpherical" register entry suggests.

    Returns
    -------
    CallbackClass
        The ``callback`` that was passed in.
    """
    template = be.from_sharedarr(template)
    ft_target = be.from_sharedarr(ft_target)
    inv_denominator = be.from_sharedarr(inv_denominator)
    numerator = be.from_sharedarr(numerator)
    template_filter = be.from_sharedarr(template_filter)

    # Batch axes are left out of the FFTs; the template side is squeezed to
    # size one along the target's batch axes.
    data_axes = None
    target_batch, template_batch = _get_batch_dim(ft_target, template)
    sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
    sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
    unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
    if len(target_batch) or len(template_batch):
        batch = (*target_batch, *template_batch)
        data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
        unpadded_slice = tuple(
            slice(None) if i in batch else slice(0, x)
            for i, x in enumerate(template.shape)
        )

    # arr_sqz and ft_sqz start out as views into arr / ft_temp
    arr = be.zeros(fast_shape, be._float_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
    arr_sqz, ft_sqz = arr[sqz_slice], ft_temp[sqz_slice]

    if be.size(template_filter) != 1:
        # The filter could be w.r.t the unpadded template
        ret_shape = tuple(
            int(x * y) if x == 1 or y == 1 else y
            for x, y in zip(sqz_cmpl, template_filter.shape)
        )
        ft_sqz = be.zeros(ret_shape, be._complex_dtype)

    norm_func, norm_template, mask_sum = normalize_template, False, 1
    if template_mask is not None:
        template_mask = be.from_sharedarr(template_mask)
        norm_template, mask_sum = True, be.sum(
            be.astype(template_mask, be._overflow_safe_dtype),
            axis=data_axes,
            keepdims=True,
        )
        if be.datatype_bytes(template_mask.dtype) == 2:
            norm_func = _normalize_template_overflow_safe

    callback_func = conditional_execute(callback, callback is not None)
    norm_template = conditional_execute(norm_func, norm_template)
    # NOTE(review): elsewhere in this file conditional_execute is called as
    # (func, condition). Here the second positional argument is ``identity``
    # and the condition comes third. Confirm this matches its signature;
    # otherwise the subtract/multiply may be applied unconditionally.
    norm_numerator = conditional_execute(
        be.subtract, identity, _shape_match(numerator.shape, fast_shape)
    )
    norm_denominator = conditional_execute(
        be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
    )

    _fftargs = {
        "real_dtype": be._float_dtype,
        "cmpl_dtype": be._complex_dtype,
        "fwd_axes": data_axes,
        "inv_axes": data_axes,
        "inv_shape": fast_ft_shape,
        "inv_output_shape": fast_shape,
        "temp_fwd": arr_sqz if _shape_match(ft_sqz.shape, sqz_cmpl) else arr,
    }

    # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
    _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
    rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
    _ = _fftargs.pop("temp_fwd", None)

    template_filter_func = _create_filter_func(
        arr_shape=template.shape,
        arr_ft_shape=sqz_cmpl,
        arr_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
        **_fftargs,
    )

    for index in range(rotations.shape[0]):
        be.fill(arr, 0)
        rotation = rotations[index]
        arr_sqz, _ = be.rigid_transform(
            arr=template,
            rotation_matrix=rotation,
            out=arr_sqz,
            use_geometric_center=True,
            order=interpolation_order,
            cache=False,
        )
        arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
        # Return value is discarded: this relies on norm_template modifying
        # the sliced view of arr_norm in place.
        norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)

        ft_sqz = rfftn(arr_norm, ft_sqz)
        ft_temp = be.multiply(ft_target, ft_sqz, out=ft_temp)
        arr = irfftn(ft_temp, arr)

        arr = norm_numerator(arr, numerator, out=arr)
        arr = norm_denominator(arr, inv_denominator, out=arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
1171
+
1172
+
1173
# Score name -> (setup function, scoring function). The "batch" entries
# reuse the regular setup functions together with the batch-aware scorers.
MATCHING_EXHAUSTIVE_REGISTER = dict(
    CC=(cc_setup, corr_scoring),
    LCC=(lcc_setup, corr_scoring),
    CORR=(corr_setup, corr_scoring),
    CAM=(cam_setup, corr_scoring),
    FLCSphericalMask=(flcSphericalMask_setup, corr_scoring),
    FLC=(flc_setup, flc_scoring),
    MCC=(mcc_setup, mcc_scoring),
    batchFLCSpherical=(flcSphericalMask_setup, corr_scoring2),
    batchFLC=(flc_setup, flc_scoring2),
)