pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/matching_scores.py ADDED
@@ -0,0 +1,887 @@
1
+ """ Implements a range of cross-correlation coefficients.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import warnings
9
+ from typing import Callable, Tuple, Dict, Optional
10
+
11
+ import numpy as np
12
+ from scipy.ndimage import laplace
13
+
14
+ from .backends import backend as be
15
+ from .types import CallbackClass, BackendArray, shm_type
16
+ from .matching_utils import (
17
+ conditional_execute,
18
+ identity,
19
+ normalize_template,
20
+ _normalize_template_overflow_safe,
21
+ )
22
+
23
+
24
+ def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
25
+ """
26
+ Determine whether ``shape1`` is equal to ``shape2``.
27
+
28
+ Parameters
29
+ ----------
30
+ shape1, shape2 : tuple of ints
31
+ Shapes to compare.
32
+
33
+ Returns
34
+ -------
35
+ Bool
36
+ ``shape1`` is equal to ``shape2``.
37
+ """
38
+ if len(shape1) != len(shape2):
39
+ return False
40
+ return shape1 == shape2
41
+
42
+
43
def _setup_template_filtering(
    forward_ft_shape: Tuple[int],
    inverse_ft_shape: Tuple[int],
    template_shape: Tuple[int],
    template_filter: BackendArray,
    rfftn: Callable = None,
    irfftn: Callable = None,
) -> Callable:
    """
    Configure template filtering function for Fourier transforms.

    Parameters
    ----------
    forward_ft_shape : tuple of ints
        Shape for the forward Fourier transform.
    inverse_ft_shape : tuple of ints
        Shape for the inverse Fourier transform.
    template_shape : tuple of ints
        Shape of the template to be filtered.
    template_filter : BackendArray
        Precomputed filter to apply in the frequency domain.
    rfftn : Callable, optional
        Real-to-complex FFT function. Built via ``be.build_fft`` if omitted.
    irfftn : Callable, optional
        Complex-to-real inverse FFT function. Built via ``be.build_fft``
        if omitted.

    Returns
    -------
    Callable
        Filter function with parameters template, ft_temp and template_filter.

    Notes
    -----
    If the shape of template_filter does not match inverse_ft_shape
    the template is assumed to be padded and cropped back to template_shape
    prior to filter application. In that case dedicated transforms matching
    the cropped shapes are always built, regardless of ``rfftn``/``irfftn``.
    """
    # A single-element filter is treated as "no filtering requested".
    if be.size(template_filter) == 1:
        return conditional_execute(identity, identity, False)

    shape_mismatch = False
    if not _shape_match(template_filter.shape, inverse_ft_shape):
        shape_mismatch = True
        forward_ft_shape = template_shape
        inverse_ft_shape = template_filter.shape

    # Only build new transforms when they are actually needed: either the
    # caller did not provide both of them, or the provided ones were planned
    # for shapes that differ from the filter's shape. Previously the test was
    # inverted, which rebuilt the transforms when they were supplied and left
    # them as None (and later called) when they were not.
    if rfftn is None or irfftn is None or shape_mismatch:
        rfftn, irfftn = be.build_fft(
            fast_shape=forward_ft_shape,
            fast_ft_shape=inverse_ft_shape,
            real_dtype=be._float_dtype,
            complex_dtype=be._complex_dtype,
            inverse_fast_shape=forward_ft_shape,
        )

    # Default case, all shapes are correctly matched
    def _apply_template_filter(template, ft_temp, template_filter):
        ft_temp = rfftn(template, ft_temp)
        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
        return irfftn(ft_temp, template)

    # Template is padded, filter is not. Crop and assign for continuous arrays
    if shape_mismatch:
        real_subset = tuple(slice(0, x) for x in forward_ft_shape)
        # Scratch buffers sized for the cropped template, reused on every call.
        _template = be.zeros(forward_ft_shape, be._float_dtype)
        _ft_temp = be.zeros(inverse_ft_shape, be._complex_dtype)

        def _apply_filter_shape_mismatch(template, ft_temp, template_filter):
            # The passed ft_temp is ignored here; its shape belongs to the
            # padded transform, not to the cropped one.
            _template[:] = template[real_subset]
            template[real_subset] = _apply_template_filter(
                _template, _ft_temp, template_filter
            )
            return template

        return _apply_filter_shape_mismatch

    return _apply_template_filter
120
+
121
+
122
def cc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for computing an unnormalized cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).


    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """
    # Forward transform of the target padded to fast_shape.
    ft_buffer = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_target = rfftn(be.topleft_pad(target, fast_shape), ft_buffer)

    # One-element neutral operands: nothing to subtract, scale by one.
    zero_offset = be.zeros(1, be._float_dtype)
    unit_scale = be.zeros(1, be._float_dtype) + 1

    ret = {"fast_shape": fast_shape, "fast_ft_shape": fast_ft_shape}
    arrays = (
        ("template", template),
        ("ft_target", ft_target),
        ("inv_denominator", unit_scale),
        ("numerator", zero_offset),
    )
    for key, value in arrays:
        ret[key] = be.to_sharedarr(value, shared_memory_handler)

    return ret
160
+
161
+
162
def lcc_setup(target: BackendArray, template: BackendArray, **kwargs) -> Dict:
    """
    Setup function for computing a laplace cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)


    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """
    # Apply the Laplace filter on numpy copies, then hand the filtered
    # arrays back to the active backend before delegating to cc_setup.
    for name, data in (("target", target), ("template", template)):
        filtered = laplace(be.to_numpy_array(data), mode="wrap")
        kwargs[name] = be.to_backend_array(filtered)
    return cc_setup(**kwargs)
180
+
181
+
182
def corr_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    template_filter: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup for computing a normalized cross-correlation between a
    ``target`` (f), a ``template`` (g) given ``template_mask`` (m)

    .. math::

        \\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
            {(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}},

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    The returned dictionary holds the precomputed ``numerator`` term
    (:math:`\\overline{g} \\cdot CC(f, m)`) and the element-wise inverse of
    the denominator (``inv_denominator``), both of shape ``fast_shape``.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.

    In this implementation :math:`N_g` is the number of elements of
    ``template`` (not the mask sum), and the square root of the clipped
    denominator product is taken before inversion.

    References
    ----------
    .. [1] Lewis P. J. Fast Normalized Cross-Correlation, Industrial Light and Magic.
    """
    target_pad = be.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable
    # scikit-image match_template multiplies the running sum of the target
    # with a scaling factor derived from the template. This is probably appropriate
    # in pattern matching situations where the template exists in the target
    ft_window = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_window = rfftn(be.topleft_pad(template_mask, fast_shape), ft_window)
    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_target2 = be.zeros(fast_ft_shape, be._complex_dtype)
    denominator = be.zeros(fast_shape, be._float_dtype)
    window_sum = be.zeros(fast_shape, be._float_dtype)

    ft_target = rfftn(target_pad, ft_target)
    ft_target2 = rfftn(be.square(target_pad), ft_target2)
    # denominator <- squared target filtered with the mask window
    ft_target2 = be.multiply(ft_target2, ft_window, out=ft_target2)
    denominator = irfftn(ft_target2, denominator)
    # window_sum <- target filtered with the mask window (ft_window is
    # overwritten, so it must not be used as the plain mask transform below)
    ft_window = be.multiply(ft_target, ft_window, out=ft_window)
    window_sum = irfftn(ft_window, window_sum)

    # Drop references to buffers that are no longer needed.
    target_pad, ft_target2, ft_window = None, None, None

    # TODO: Factor in template_filter here
    if be.size(template_filter) != 1:
        warnings.warn(
            "CORR scores obtained with template_filter are not correctly scaled. "
            "Please use a different score or consider only relative peak heights."
        )
    n_obs, norm_func = be.sum(template_mask), normalize_template
    # Masks stored in a 2-byte dtype are summed in a wider dtype and use the
    # overflow-safe normalization routine.
    if be.datatype_bytes(template_mask.dtype) == 2:
        norm_func = _normalize_template_overflow_safe
        n_obs = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    template = norm_func(template, template_mask, n_obs)
    template_mean = be.sum(be.multiply(template, template_mask))
    template_mean = be.divide(template_mean, n_obs)
    template_ssd = be.sum(be.square(template - template_mean) * template_mask)
    template_volume = np.prod(tuple(int(x) for x in template.shape))
    template = be.multiply(template, template_mask, out=template)

    # numerator must be computed before window_sum is squared in place.
    numerator = be.multiply(window_sum, template_mean)
    window_sum = be.square(window_sum, out=window_sum)
    window_sum = be.divide(window_sum, template_volume, out=window_sum)
    denominator = be.subtract(denominator, window_sum, out=denominator)
    denominator = be.multiply(denominator, template_ssd, out=denominator)
    # Clip negative values (presumably numerical noise) before the square root.
    denominator = be.maximum(denominator, 0, out=denominator)
    denominator = be.sqrt(denominator, out=denominator)

    # Invert the denominator without dividing by (near-)zero: entries at or
    # below machine epsilon are replaced by one for the division and then
    # zeroed, so the corresponding scores become zero.
    mask = denominator > be.eps(be._float_dtype)
    denominator = be.multiply(denominator, mask, out=denominator)
    denominator = be.add(denominator, ~mask, out=denominator)
    denominator = be.divide(1, denominator, out=denominator)
    denominator = be.multiply(denominator, mask, out=denominator)

    ret = {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "inv_denominator": be.to_sharedarr(denominator, shared_memory_handler),
        "numerator": be.to_sharedarr(numerator, shared_memory_handler),
    }

    return ret
281
+
282
+
283
def cam_setup(template: BackendArray, target: BackendArray, **kwargs) -> Dict:
    """
    Like :py:meth:`corr_setup` but with standardized ``target``, ``template``

    .. math::

        f' = \\frac{f - \\overline{f}}{\\sigma_f}.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """

    def _standardize(arr):
        # Shift to zero mean, then scale by the standard deviation of the input.
        centered = arr - be.mean(arr)
        return centered / be.std(arr)

    return corr_setup(
        template=_standardize(template), target=_standardize(target), **kwargs
    )
298
+
299
+
300
def flc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`flc_scoring`.

    Precomputes the Fourier transforms of the padded target and of its
    square, and normalizes the template under ``template_mask``.
    """
    padded = be.topleft_pad(target, fast_shape)
    ft_plain = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_squared = be.zeros(fast_ft_shape, be._complex_dtype)

    ft_plain = rfftn(padded, ft_plain)
    # Square the padded target in place; the plain transform is already done.
    padded = be.square(padded, out=padded)
    ft_squared = rfftn(padded, ft_squared)

    n_obs = be.sum(template_mask)
    template = normalize_template(template, template_mask, n_obs)

    ret = {"fast_shape": fast_shape, "fast_ft_shape": fast_ft_shape}
    arrays = (
        ("template", template),
        ("template_mask", template_mask),
        ("ft_target", ft_plain),
        ("ft_target2", ft_squared),
    )
    for key, value in arrays:
        ret[key] = be.to_sharedarr(value, shared_memory_handler)

    return ret
333
+
334
+
335
def flcSphericalMask_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup for :py:meth:`corr_scoring`, like :py:meth:`flc_setup` but for rotation
    invariant masks.

    Because the mask is not rotated, the denominator is computed once here
    and returned as ``inv_denominator``. ``numerator`` is a one-element zero
    array.
    """
    n_obs, norm_func = be.sum(template_mask), normalize_template
    # Masks stored in a 2-byte dtype are summed in a wider dtype and use the
    # overflow-safe normalization routine.
    if be.datatype_bytes(template_mask.dtype) == 2:
        norm_func = _normalize_template_overflow_safe
        n_obs = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    target_pad = be.topleft_pad(target, fast_shape)
    temp = be.zeros(fast_shape, be._float_dtype)
    temp2 = be.zeros(fast_shape, be._float_dtype)
    numerator = be.zeros(1, be._float_dtype)
    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_template_mask = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)

    template = norm_func(template, template_mask, n_obs)
    ft_template_mask = rfftn(
        be.topleft_pad(template_mask, fast_shape), ft_template_mask
    )

    # E(X^2) - E(X)^2
    # ft_target temporarily holds the transform of the squared target ...
    ft_target = rfftn(be.square(target_pad), ft_target)
    ft_temp = be.multiply(ft_target, ft_template_mask, out=ft_temp)
    temp2 = irfftn(ft_temp, temp2)
    temp2 = be.divide(temp2, n_obs, out=temp2)

    # ... and is then overwritten with the transform of the target itself,
    # which is the value returned under "ft_target".
    ft_target = rfftn(target_pad, ft_target)
    ft_temp = be.multiply(ft_target, ft_template_mask, out=ft_temp)
    temp = irfftn(ft_temp, temp)
    temp = be.divide(temp, n_obs, out=temp)
    temp = be.square(temp, out=temp)

    temp = be.subtract(temp2, temp, out=temp)
    temp = be.maximum(temp, 0.0, out=temp)
    temp = be.sqrt(temp, out=temp)

    # Avoid divide by zero warnings
    mask = temp > be.eps(be._float_dtype)
    temp = be.multiply(temp, mask * n_obs, out=temp)
    temp = be.add(temp, ~mask, out=temp)
    # The result is written into temp; temp2 is rebound to that output.
    temp2 = be.divide(1, temp, out=temp)
    temp2 = be.multiply(temp2, mask, out=temp2)

    ret = {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "template_mask": be.to_sharedarr(template_mask, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "inv_denominator": be.to_sharedarr(temp2, shared_memory_handler),
        "numerator": be.to_sharedarr(numerator, shared_memory_handler),
    }

    return ret
402
+
403
+
404
def mcc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    target_mask: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: Callable,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`mcc_scoring`.

    Zeroes the target outside of ``target_mask`` (in place) and precomputes
    the Fourier transforms of the masked target, its square and the target
    mask.
    """
    # Keep only target values where the mask is positive (written into target).
    target = be.multiply(target, target_mask > 0, out=target)
    padded = be.topleft_pad(target, fast_shape)

    ft_plain = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_squared = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_mask = be.zeros(fast_ft_shape, be._complex_dtype)

    ft_plain = rfftn(padded, ft_plain)
    ft_squared = rfftn(be.square(padded), ft_squared)
    ft_mask = rfftn(be.topleft_pad(target_mask, fast_shape), ft_mask)

    ret = {"fast_shape": fast_shape, "fast_ft_shape": fast_ft_shape}
    arrays = (
        ("template", template),
        ("template_mask", template_mask),
        ("ft_target", ft_plain),
        ("ft_target2", ft_squared),
        ("ft_target_mask", ft_mask),
    )
    for key, value in arrays:
        ret[key] = be.to_sharedarr(value, shared_memory_handler)

    return ret
441
+
442
+
443
def corr_scoring(
    template: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    inv_denominator: shm_type,
    numerator: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    template_mask: shm_type = None,
) -> Optional[CallbackClass]:
    """
    Calculates a normalized cross-correlation between a target f and a template g.

    .. math::

        (CC(f,g) - \\text{numerator}) \\cdot \\text{inv_denominator},

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    inv_denominator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Inverse denominator data buffer, its shape and datatype.
    numerator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Numerator data buffer, its shape, and its datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
        Template mask data buffer, its shape and datatype, None by default.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.

    Notes
    -----
    ``numerator`` and ``inv_denominator`` are only applied when their shape
    equals ``fast_shape``. The setup functions in this module otherwise pass
    one-element neutral values (zero and one), for which skipping the
    operation yields the same scores.
    """
    template = be.from_sharedarr(template)
    ft_target = be.from_sharedarr(ft_target)
    inv_denominator = be.from_sharedarr(inv_denominator)
    numerator = be.from_sharedarr(numerator)
    template_filter = be.from_sharedarr(template_filter)

    norm_func, norm_template, mask_sum = normalize_template, False, 1
    if template_mask is not None:
        template_mask = be.from_sharedarr(template_mask)
        norm_template, mask_sum = True, be.sum(template_mask)
        # Masks stored in a 2-byte dtype are summed in a wider dtype and use
        # the overflow-safe normalization routine.
        if be.datatype_bytes(template_mask.dtype) == 2:
            norm_func = _normalize_template_overflow_safe
            mask_sum = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    callback_func = conditional_execute(callback, callback is not None)
    norm_template = conditional_execute(norm_func, norm_template)

    # Decide explicitly whether the per-voxel normalization terms apply.
    # This used to go through conditional_execute with a different positional
    # argument order than the two calls above, so which callable was selected
    # depended on that helper's signature. Shapes are converted to tuples so
    # the comparison does not depend on the container type of fast_shape.
    subtract_numerator = _shape_match(tuple(numerator.shape), tuple(fast_shape))
    scale_by_denominator = _shape_match(
        tuple(inv_denominator.shape), tuple(fast_shape)
    )

    arr = be.zeros(fast_shape, be._float_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=be._float_dtype,
        complex_dtype=be._complex_dtype,
        temp_real=arr,
        temp_fft=ft_temp,
    )

    template_filter_func = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )
    # Region of the padded array that holds the (unpadded) template.
    unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        arr = be.fill(arr, 0)
        arr, _ = be.rigid_transform(
            arr=template,
            rotation_matrix=rotation,
            out=arr,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )
        arr = template_filter_func(arr, ft_temp, template_filter)
        # Return value is ignored; the normalization is expected to act on
        # the passed view of arr.
        norm_template(arr[unpadded_slice], template_mask, mask_sum)

        ft_temp = rfftn(arr, ft_temp)
        ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
        arr = irfftn(ft_temp, arr)

        if subtract_numerator:
            arr = be.subtract(arr, numerator, out=arr)
        if scale_by_denominator:
            arr = be.multiply(arr, inv_denominator, out=arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
565
+
566
+
567
def flc_scoring(
    template: shm_type,
    template_mask: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    template_filter: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
) -> Optional[CallbackClass]:
    """
    Computes a normalized cross-correlation between ``target`` (f),
    ``template`` (g), and ``template_mask`` (m)

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
            {N_m * \\sqrt{
                \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
            },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and Nm is the sum of m, evaluated on the rotated mask for every rotation.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    fast_shape : tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape : tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.

    References
    ----------
    .. [1]  Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    template_mask = be.from_sharedarr(template_mask)
    ft_target = be.from_sharedarr(ft_target)
    ft_target2 = be.from_sharedarr(ft_target2)
    template_filter = be.from_sharedarr(template_filter)

    # Work buffers reused across rotations:
    # arr holds the rotated template and finally the scores, temp the
    # rotated mask and then CC(f, m), temp2 CC(f^2, m).
    arr = be.zeros(fast_shape, float_dtype)
    temp = be.zeros(fast_shape, float_dtype)
    temp2 = be.zeros(fast_shape, float_dtype)
    ft_temp = be.zeros(fast_ft_shape, complex_dtype)
    ft_denom = be.zeros(fast_ft_shape, complex_dtype)

    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=float_dtype,
        complex_dtype=complex_dtype,
        temp_real=arr,
        temp_fft=ft_temp,
    )

    template_filter_func = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )

    eps = be.eps(float_dtype)
    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        arr = be.fill(arr, 0)
        temp = be.fill(temp, 0)
        # Rotate template and mask together into the zeroed work buffers.
        arr, temp = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotations[index],
            out=arr,
            out_mask=temp,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )

        # n_obs must be taken before temp is overwritten with CC(f, m) below.
        n_obs = be.sum(temp)
        arr = template_filter_func(arr, ft_temp, template_filter)
        arr = normalize_template(arr, temp, n_obs)

        # Denominator terms: the transform of the rotated mask is multiplied
        # with the transforms of the target and of the squared target.
        ft_temp = rfftn(temp, ft_temp)
        ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
        temp = irfftn(ft_denom, temp)
        ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
        temp2 = irfftn(ft_denom, temp2)

        # Numerator: normalized rotated template against the target.
        ft_temp = rfftn(arr, ft_temp)
        ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
        arr = irfftn(ft_temp, arr)

        arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
698
+
699
+
700
def mcc_scoring(
    template: shm_type,
    template_mask: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    ft_target_mask: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    overlap_ratio: float = 0.3,
) -> CallbackClass:
    """
    Computes a normalized cross-correlation score between ``target`` (f),
    ``template`` (g), ``template_mask`` (m) and ``target_mask`` (t)

    .. math::

        \\frac{
            CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
        }{
        \\sqrt{
            (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
            (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
            }
        },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    ft_target_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target mask data buffer, its shape and datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    overlap_ratio : float, optional
        Required fractional mask overlap, 0.3 by default.

    Returns
    -------
    CallbackClass
        The ``callback`` object that was passed in.

    References
    ----------
    .. [1]  Masked FFT registration, Dirk Padfield, CVPR 2010 conference
    .. [2]  https://scikit-image.org/docs/stable/api/skimage.registration.html
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    target_ft = be.from_sharedarr(ft_target)
    target_ft2 = be.from_sharedarr(ft_target2)
    template_mask = be.from_sharedarr(template_mask)
    target_mask_ft = be.from_sharedarr(ft_target_mask)
    template_filter = be.from_sharedarr(template_filter)

    # Reduce over all axes when computing the per-rotation maxima below.
    axes = tuple(range(template.ndim))
    eps = be.eps(float_dtype)

    # Allocate score and process specific arrays
    template_rot = be.zeros(fast_shape, float_dtype)
    mask_overlap = be.zeros(fast_shape, float_dtype)
    numerator = be.zeros(fast_shape, float_dtype)
    temp = be.zeros(fast_shape, float_dtype)
    temp2 = be.zeros(fast_shape, float_dtype)
    temp3 = be.zeros(fast_shape, float_dtype)
    temp_ft = be.zeros(fast_ft_shape, complex_dtype)

    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=float_dtype,
        complex_dtype=complex_dtype,
        temp_real=numerator,
        temp_fft=temp_ft,
    )

    template_filter_func = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )

    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        template_rot = be.fill(template_rot, 0)
        temp = be.fill(temp, 0)
        # Rotated template goes to template_rot, rotated mask to temp.
        # Return values are not rebound here.
        be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation,
            out=template_rot,
            out_mask=temp,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )

        # Both calls are used for their effect on template_rot; their
        # return values are discarded.
        template_filter_func(template_rot, temp_ft, template_filter)
        normalize_template(template_rot, temp, be.sum(temp))

        # temp2 <- template against target mask, numerator <- template
        # against target.
        temp_ft = rfftn(template_rot, temp_ft)
        temp2 = irfftn(target_mask_ft * temp_ft, temp2)
        numerator = irfftn(target_ft * temp_ft, numerator)

        # temp template_mask_rot | temp_ft template_mask_rot_ft
        # Calculate overlap of masks at every point in the convolution.
        # Locations with high overlap should not be taken into account.
        # NOTE(review): the threshold below zeroes locations whose overlap is
        # *below* overlap_ratio times the maximum - confirm the wording above.
        temp_ft = rfftn(temp, temp_ft)
        mask_overlap = irfftn(temp_ft * target_mask_ft, mask_overlap)
        # Lower-bound the overlap by eps since it is used as a divisor.
        be.maximum(mask_overlap, eps, out=mask_overlap)
        temp = irfftn(temp_ft * target_ft, temp)

        be.subtract(
            numerator,
            be.divide(be.multiply(temp, temp2), mask_overlap),
            out=numerator,
        )

        # temp_3 = fixed_denom
        be.multiply(temp_ft, target_ft2, out=temp_ft)
        temp3 = irfftn(temp_ft, temp3)
        be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
        be.maximum(temp3, 0.0, out=temp3)

        # temp = moving_denom
        temp_ft = rfftn(be.square(template_rot), temp_ft)
        be.multiply(target_mask_ft, temp_ft, out=temp_ft)
        temp = irfftn(temp_ft, temp)

        be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
        be.maximum(temp, 0.0, out=temp)

        # temp_2 = denom
        be.multiply(temp3, temp, out=temp)
        be.sqrt(temp, temp2)

        # Pixels where `denom` is very small will introduce large
        # numbers after division. To get around this problem,
        # we zero-out problematic pixels.
        tol = 1e3 * eps * be.max(be.abs(temp2), axis=axes, keepdims=True)

        # Small denominators are set to one before the division.
        temp2[temp2 < tol] = 1
        temp = be.divide(numerator, temp2, out=temp)
        temp = be.clip(temp, a_min=-1, a_max=1, out=temp)

        # Apply overlap ratio threshold
        number_px_threshold = overlap_ratio * be.max(
            mask_overlap, axis=axes, keepdims=True
        )
        temp[mask_overlap < number_px_threshold] = 0.0
        callback_func(temp, rotation_matrix=rotation)

    return callback
877
+
878
+
879
# Maps a score name to its (setup function, scoring function) pair. The
# setup function precomputes the data consumed by the scoring function.
MATCHING_EXHAUSTIVE_REGISTER = {
    "CC": (cc_setup, corr_scoring),
    "LCC": (lcc_setup, corr_scoring),
    "CORR": (corr_setup, corr_scoring),
    "CAM": (cam_setup, corr_scoring),
    "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
    "FLC": (flc_setup, flc_scoring),
    "MCC": (mcc_setup, mcc_scoring),
}