pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +23 -10
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_backends.py +3 -9
  20. tests/test_density.py +0 -1
  21. tests/test_matching_utils.py +10 -60
  22. tests/test_rotations.py +1 -1
  23. tme/__version__.py +1 -1
  24. tme/analyzer/_utils.py +4 -4
  25. tme/analyzer/aggregation.py +13 -3
  26. tme/analyzer/peaks.py +11 -10
  27. tme/backends/_jax_utils.py +15 -13
  28. tme/backends/_numpyfftw_utils.py +270 -0
  29. tme/backends/cupy_backend.py +5 -44
  30. tme/backends/jax_backend.py +58 -37
  31. tme/backends/matching_backend.py +6 -51
  32. tme/backends/mlx_backend.py +1 -27
  33. tme/backends/npfftw_backend.py +68 -65
  34. tme/backends/pytorch_backend.py +1 -26
  35. tme/density.py +2 -6
  36. tme/extensions.cpython-311-darwin.so +0 -0
  37. tme/filters/ctf.py +22 -21
  38. tme/filters/wedge.py +10 -7
  39. tme/mask.py +341 -0
  40. tme/matching_data.py +7 -19
  41. tme/matching_exhaustive.py +34 -47
  42. tme/matching_optimization.py +2 -1
  43. tme/matching_scores.py +206 -411
  44. tme/matching_utils.py +73 -422
  45. tme/memory.py +1 -1
  46. tme/orientations.py +4 -6
  47. tme/rotations.py +1 -1
  48. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  49. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
  50. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
  51. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  52. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/matching_scores.py CHANGED
@@ -18,138 +18,9 @@ from .matching_utils import (
18
18
  conditional_execute,
19
19
  identity,
20
20
  normalize_template,
21
- _normalize_template_overflow_safe,
22
21
  )
23
22
 
24
23
 
25
- def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
26
- """
27
- Determine whether ``shape1`` is equal to ``shape2``.
28
-
29
- Parameters
30
- ----------
31
- shape1, shape2 : tuple of ints
32
- Shapes to compare.
33
-
34
- Returns
35
- -------
36
- Bool
37
- ``shape1`` is equal to ``shape2``.
38
- """
39
- if len(shape1) != len(shape2):
40
- return False
41
- return shape1 == shape2
42
-
43
-
44
- def _create_filter_func(
45
- fwd_shape: Tuple[int],
46
- inv_shape: Tuple[int],
47
- arr_shape: Tuple[int],
48
- arr_filter: BackendArray,
49
- arr_ft_shape: Tuple[int],
50
- inv_output_shape: Tuple[int],
51
- real_dtype: type,
52
- cmpl_dtype: type,
53
- fwd_axes=None,
54
- inv_axes=None,
55
- rfftn: Callable = None,
56
- irfftn: Callable = None,
57
- ) -> Callable:
58
- """
59
- Configure template filtering function for Fourier transforms.
60
-
61
- Conceptually we distinguish between three cases. The base case
62
- is that both template and the corresponding filter have the same
63
- shape. Padding is used when the template filter is larger than
64
- the template, for instance to better resolve Fourier filters. Finally
65
- this function also handles the case when a filter is supposed to be
66
- broadcasted over the template batch dimension.
67
-
68
- Parameters
69
- ----------
70
- fwd_shape : tuple of ints
71
- Input shape of rfftn.
72
- inv_shape : tuple of ints
73
- Input shape of irfftn.
74
- arr_shape : tuple of ints
75
- Shape of the array to be filtered.
76
- arr_ft_shape : tuple of ints
77
- Shape of the Fourier transform of the array.
78
- arr_filter : BackendArray
79
- Precomputed filter to apply in the frequency domain.
80
- rfftn : Callable, optional
81
- Foward Fourier transform.
82
- irfftn : Callable, optional
83
- Inverse Fourier transform.
84
-
85
- Returns
86
- -------
87
- Callable
88
- Filter function with parameters template, ft_temp and template_filter.
89
- """
90
- if be.size(arr_filter) == 1:
91
- return conditional_execute(identity, identity, False)
92
-
93
- filter_shape = tuple(int(x) for x in arr_filter.shape)
94
- try:
95
- product_ft_shape = np.broadcast_shapes(arr_ft_shape, filter_shape)
96
- except ValueError:
97
- product_ft_shape, inv_output_shape = filter_shape, arr_shape
98
-
99
- rfft_valid = _shape_match(arr_shape, fwd_shape)
100
- rfft_valid = rfft_valid and _shape_match(product_ft_shape, inv_shape)
101
- rfft_valid = rfft_valid and rfftn is not None and irfftn is not None
102
-
103
- # FTTs were not or built for the wrong shape
104
- if not rfft_valid:
105
- _fwd_shape = arr_shape
106
- if all(x > y for x, y in zip(arr_shape, product_ft_shape)):
107
- _fwd_shape = fwd_shape
108
-
109
- rfftn, irfftn = be.build_fft(
110
- fwd_shape=_fwd_shape,
111
- inv_shape=product_ft_shape,
112
- real_dtype=real_dtype,
113
- cmpl_dtype=cmpl_dtype,
114
- inv_output_shape=inv_output_shape,
115
- fwd_axes=fwd_axes,
116
- inv_axes=inv_axes,
117
- )
118
-
119
- # Default case, all shapes are correctly matched
120
- def _apply_filter(template, ft_temp, template_filter):
121
- ft_temp = rfftn(template, ft_temp)
122
- ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
123
- return irfftn(ft_temp, template)
124
-
125
- if not _shape_match(arr_ft_shape, filter_shape):
126
- real_subset = tuple(slice(0, x) for x in arr_shape)
127
- _template = be.zeros(arr_shape, be._float_dtype)
128
- _ft_temp = be.zeros(product_ft_shape, be._complex_dtype)
129
-
130
- # Arr is padded, filter is not
131
- def _apply_filter_subset(template, ft_temp, template_filter):
132
- # TODO: Benchmark this
133
- _template[:] = template[real_subset]
134
- template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
135
- return template
136
-
137
- # Filter application requires a broadcasting operation
138
- def _apply_filter_broadcast(template, ft_temp, template_filter):
139
- _ft_prod = rfftn(template, _ft_temp2)
140
- _ft_res = be.multiply(_ft_prod, template_filter, out=_ft_temp)
141
- return irfftn(_ft_res, _template)
142
-
143
- if any(x > y and y == 1 for x, y in zip(filter_shape, arr_ft_shape)):
144
- _template = be.zeros(inv_output_shape, be._float_dtype)
145
- _ft_temp2 = be.zeros((1, *product_ft_shape[1:]), be._complex_dtype)
146
- return _apply_filter_broadcast
147
-
148
- return _apply_filter_subset
149
-
150
- return _apply_filter
151
-
152
-
153
24
  def cc_setup(
154
25
  matching_data: type,
155
26
  fast_shape: Tuple[int],
@@ -176,8 +47,6 @@ def cc_setup(
176
47
  axes = matching_data._batch_axis(matching_data._batch_mask)
177
48
 
178
49
  ret = {
179
- "fast_shape": fast_shape,
180
- "fast_ft_shape": fast_ft_shape,
181
50
  "template": be.to_sharedarr(matching_data.template, shm_handler),
182
51
  "ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
183
52
  "inv_denominator": be.to_sharedarr(
@@ -219,8 +88,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
219
88
  for subset in subsets:
220
89
  template[subset] = laplace(template[subset], mode="wrap")
221
90
 
222
- matching_data._target = target
223
- matching_data._template = template
91
+ matching_data._target = be.to_backend_array(target)
92
+ matching_data._template = be.to_backend_array(template)
224
93
 
225
94
  return cc_setup(matching_data=matching_data, **kwargs)
226
95
 
@@ -316,8 +185,6 @@ def corr_setup(
316
185
  denominator = be.multiply(denominator, mask, out=denominator)
317
186
 
318
187
  ret = {
319
- "fast_shape": fast_shape,
320
- "fast_ft_shape": fast_ft_shape,
321
188
  "template": be.to_sharedarr(template, shm_handler),
322
189
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
323
190
  "inv_denominator": be.to_sharedarr(denominator, shm_handler),
@@ -376,8 +243,6 @@ def flc_setup(
376
243
  ft_target2 = be.rfftn(target_pad, axes=data_axes)
377
244
 
378
245
  ret = {
379
- "fast_shape": fast_shape,
380
- "fast_ft_shape": fast_ft_shape,
381
246
  "template": be.to_sharedarr(matching_data.template, shm_handler),
382
247
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
383
248
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -438,8 +303,6 @@ def flcSphericalMask_setup(
438
303
 
439
304
  temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
440
305
  ret = {
441
- "fast_shape": fast_shape,
442
- "fast_ft_shape": fast_ft_shape,
443
306
  "template": be.to_sharedarr(matching_data.template, shm_handler),
444
307
  "template_mask": be.to_sharedarr(template_mask, shm_handler),
445
308
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -461,21 +324,14 @@ def mcc_setup(
461
324
  Setup function for :py:meth:`mcc_scoring`.
462
325
  """
463
326
  target, target_mask = matching_data.target, matching_data.target_mask
464
- target = be.multiply(target, target_mask > 0, out=target)
327
+ target = be.multiply(target, target_mask, out=target)
465
328
 
466
- target = be.topleft_pad(
467
- target,
468
- matching_data._batch_shape(fast_shape, matching_data._template_batch),
469
- )
470
- target_mask = be.topleft_pad(
471
- target_mask,
472
- matching_data._batch_shape(fast_shape, matching_data._template_batch),
473
- )
474
329
  ax = matching_data._batch_axis(matching_data._batch_mask)
330
+ shape = matching_data._batch_shape(fast_shape, matching_data._template_batch)
331
+ target = be.topleft_pad(target, shape)
332
+ target_mask = be.topleft_pad(target_mask, shape)
475
333
 
476
334
  ret = {
477
- "fast_shape": fast_shape,
478
- "fast_ft_shape": fast_ft_shape,
479
335
  "template": be.to_sharedarr(matching_data.template, shm_handler),
480
336
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
481
337
  "ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
@@ -541,8 +397,7 @@ def corr_scoring(
541
397
 
542
398
  Returns
543
399
  -------
544
- Optional[CallbackClass]
545
- ``callback`` if provided otherwise None.
400
+ CallbackClass
546
401
  """
547
402
  template = be.from_sharedarr(template)
548
403
  ft_target = be.from_sharedarr(ft_target)
@@ -550,58 +405,30 @@ def corr_scoring(
550
405
  numerator = be.from_sharedarr(numerator)
551
406
  template_filter = be.from_sharedarr(template_filter)
552
407
 
553
- norm_func, norm_template, mask_sum = normalize_template, False, 1
408
+ n_obs = None
554
409
  if template_mask is not None:
555
410
  template_mask = be.from_sharedarr(template_mask)
556
- norm_template, mask_sum = True, be.sum(template_mask)
557
- if be.datatype_bytes(template_mask.dtype) == 2:
558
- norm_func = _normalize_template_overflow_safe
559
- mask_sum = be.sum(be.astype(template_mask, be._overflow_safe_dtype))
560
-
561
- callback_func = conditional_execute(callback, callback is not None)
562
- norm_template = conditional_execute(norm_func, norm_template)
563
- norm_numerator = conditional_execute(
564
- be.subtract, identity, _shape_match(numerator.shape, fast_shape)
565
- )
566
- norm_denominator = conditional_execute(
567
- be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
568
- )
411
+ n_obs = be.sum(template_mask) if template_mask is not None else None
412
+
413
+ norm_template = conditional_execute(normalize_template, n_obs is not None)
414
+ norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
415
+ norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
569
416
 
570
417
  arr = be.zeros(fast_shape, be._float_dtype)
571
418
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
572
-
573
- _fftargs = {
574
- "real_dtype": be._float_dtype,
575
- "cmpl_dtype": be._complex_dtype,
576
- "inv_output_shape": fast_shape,
577
- "fwd_axes": None,
578
- "inv_axes": None,
579
- "inv_shape": fast_ft_shape,
580
- "temp_fwd": arr,
581
- }
582
-
583
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
584
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
585
- _ = _fftargs.pop("temp_fwd", None)
419
+ template_rot = be.zeros(template.shape, be._float_dtype)
586
420
 
587
421
  template_filter_func = _create_filter_func(
588
422
  arr_shape=template.shape,
589
- arr_ft_shape=fast_ft_shape,
590
- arr_filter=template_filter,
591
- rfftn=rfftn,
592
- irfftn=irfftn,
593
- **_fftargs,
423
+ filter_shape=template_filter.shape,
594
424
  )
595
425
 
596
426
  center = be.divide(be.to_backend_array(template.shape) - 1, 2)
597
427
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
598
-
599
- template_rot = be.zeros(template.shape, be._float_dtype)
600
428
  for index in range(rotations.shape[0]):
601
- # d+1, d+1 rigid transform matrix from d,d rotation matrix
602
429
  rotation = rotations[index]
603
430
  matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
604
- template_rot, _ = be.rigid_transform(
431
+ _ = be.rigid_transform(
605
432
  arr=template,
606
433
  rotation_matrix=matrix,
607
434
  out=template_rot,
@@ -610,18 +437,17 @@ def corr_scoring(
610
437
  )
611
438
 
612
439
  template_rot = template_filter_func(template_rot, ft_temp, template_filter)
613
- norm_template(template_rot, template_mask, mask_sum)
440
+ norm_template(template_rot, template_mask, n_obs)
614
441
 
615
442
  arr = be.fill(arr, 0)
616
443
  arr[unpadded_slice] = template_rot
617
444
 
618
- ft_temp = rfftn(arr, ft_temp)
619
- ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
620
- arr = irfftn(ft_temp, arr)
445
+ ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
446
+ arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
621
447
 
622
- arr = norm_numerator(arr, numerator, out=arr)
623
- arr = norm_denominator(arr, inv_denominator, out=arr)
624
- callback_func(arr, rotation_matrix=rotation)
448
+ arr = norm_sub(arr, numerator, out=arr)
449
+ arr = norm_mul(arr, inv_denominator, out=arr)
450
+ callback(arr, rotation_matrix=rotation)
625
451
 
626
452
  return callback
627
453
 
@@ -682,6 +508,10 @@ def flc_scoring(
682
508
  interpolation_order : int
683
509
  Spline order for template rotations.
684
510
 
511
+ Returns
512
+ -------
513
+ CallbackClass
514
+
685
515
  References
686
516
  ----------
687
517
  .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
@@ -698,63 +528,46 @@ def flc_scoring(
698
528
  temp2 = be.zeros(fast_shape, float_dtype)
699
529
  ft_temp = be.zeros(fast_ft_shape, complex_dtype)
700
530
  ft_denom = be.zeros(fast_ft_shape, complex_dtype)
531
+ template_rot = be.zeros(template.shape, be._float_dtype)
532
+ template_mask_rot = be.zeros(template.shape, be._float_dtype)
701
533
 
702
- _fftargs = {
703
- "real_dtype": be._float_dtype,
704
- "cmpl_dtype": be._complex_dtype,
705
- "inv_output_shape": fast_shape,
706
- "fwd_axes": None,
707
- "inv_axes": None,
708
- "inv_shape": fast_ft_shape,
709
- "temp_fwd": arr,
710
- }
711
-
712
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
713
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
714
- _ = _fftargs.pop("temp_fwd", None)
715
-
716
- template_filter_func = _create_filter_func(
717
- arr_shape=template.shape,
718
- arr_ft_shape=fast_ft_shape,
719
- arr_filter=template_filter,
720
- rfftn=rfftn,
721
- irfftn=irfftn,
722
- **_fftargs,
723
- )
534
+ tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
724
535
 
725
536
  eps = be.eps(float_dtype)
726
- callback_func = conditional_execute(callback, callback is not None)
537
+ center = be.divide(be.to_backend_array(template.shape) - 1, 2)
538
+ unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
727
539
  for index in range(rotations.shape[0]):
728
540
  rotation = rotations[index]
729
- arr = be.fill(arr, 0)
730
- temp = be.fill(temp, 0)
731
- arr, temp = be.rigid_transform(
541
+ matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
542
+ _ = be.rigid_transform(
732
543
  arr=template,
733
544
  arr_mask=template_mask,
734
- rotation_matrix=rotation,
735
- out=arr,
736
- out_mask=temp,
545
+ rotation_matrix=matrix,
546
+ out=template_rot,
547
+ out_mask=template_mask_rot,
737
548
  use_geometric_center=True,
738
549
  order=interpolation_order,
739
550
  cache=True,
740
551
  )
741
552
 
742
- n_obs = be.sum(temp)
743
- arr = template_filter_func(arr, ft_temp, template_filter)
744
- arr = normalize_template(arr, temp, n_obs, axis=None)
553
+ n_obs = be.sum(template_mask_rot)
554
+ template_rot = tmpl_filter_func(template_rot, ft_temp, template_filter)
555
+ template_rot = normalize_template(template_rot, template_mask_rot, n_obs)
556
+
557
+ arr = be.fill(arr, 0)
558
+ temp = be.fill(temp, 0)
559
+ arr[unpadded_slice] = template_rot
560
+ temp[unpadded_slice] = template_mask_rot
745
561
 
746
- ft_temp = rfftn(temp, ft_temp)
747
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
748
- temp = irfftn(ft_denom, temp)
749
- ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
750
- temp2 = irfftn(ft_denom, temp2)
562
+ ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
563
+ temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
564
+ temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp, fast_shape)
751
565
 
752
- ft_temp = rfftn(arr, ft_temp)
753
- ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
754
- arr = irfftn(ft_temp, arr)
566
+ ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
567
+ arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
755
568
 
756
569
  arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
757
- callback_func(arr, rotation_matrix=rotation)
570
+ callback(arr, rotation_matrix=rotation)
758
571
 
759
572
  return callback
760
573
 
@@ -821,6 +634,10 @@ def mcc_scoring(
821
634
  overlap_ratio : float, optional
822
635
  Required fractional mask overlap, 0.3 by default.
823
636
 
637
+ Returns
638
+ -------
639
+ CallbackClass
640
+
824
641
  References
825
642
  ----------
826
643
  .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
@@ -846,30 +663,12 @@ def mcc_scoring(
846
663
  temp3 = be.zeros(fast_shape, float_dtype)
847
664
  temp_ft = be.zeros(fast_ft_shape, complex_dtype)
848
665
 
849
- _fftargs = {
850
- "real_dtype": be._float_dtype,
851
- "cmpl_dtype": be._complex_dtype,
852
- "inv_output_shape": fast_shape,
853
- "fwd_axes": None,
854
- "inv_axes": None,
855
- "inv_shape": fast_ft_shape,
856
- "temp_fwd": temp,
857
- }
858
-
859
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
860
- rfftn, irfftn = be.build_fft(temp_inv=temp_ft, **_fftargs)
861
- _ = _fftargs.pop("temp_fwd", None)
862
-
863
666
  template_filter_func = _create_filter_func(
864
667
  arr_shape=template.shape,
865
- arr_ft_shape=fast_ft_shape,
866
- arr_filter=template_filter,
867
- rfftn=rfftn,
868
- irfftn=irfftn,
869
- **_fftargs,
668
+ filter_shape=template_filter.shape,
669
+ arr_padded=True,
870
670
  )
871
671
 
872
- callback_func = conditional_execute(callback, callback is not None)
873
672
  for index in range(rotations.shape[0]):
874
673
  rotation = rotations[index]
875
674
  template_rot = be.fill(template_rot, 0)
@@ -888,17 +687,19 @@ def mcc_scoring(
888
687
  template_filter_func(template_rot, temp_ft, template_filter)
889
688
  normalize_template(template_rot, temp, be.sum(temp))
890
689
 
891
- temp_ft = rfftn(template_rot, temp_ft)
892
- temp2 = irfftn(target_mask_ft * temp_ft, temp2)
893
- numerator = irfftn(target_ft * temp_ft, numerator)
690
+ temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
691
+ temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
692
+ numerator = be.irfftn(target_ft * temp_ft, out=numerator, s=fast_shape)
894
693
 
895
694
  # temp template_mask_rot | temp_ft template_mask_rot_ft
896
695
  # Calculate overlap of masks at every point in the convolution.
897
696
  # Locations with high overlap should not be taken into account.
898
- temp_ft = rfftn(temp, temp_ft)
899
- mask_overlap = irfftn(temp_ft * target_mask_ft, mask_overlap)
697
+ temp_ft = be.rfftn(temp, out=temp_ft, s=fast_shape)
698
+ mask_overlap = be.irfftn(
699
+ temp_ft * target_mask_ft, out=mask_overlap, s=fast_shape
700
+ )
900
701
  be.maximum(mask_overlap, eps, out=mask_overlap)
901
- temp = irfftn(temp_ft * target_ft, temp)
702
+ temp = be.irfftn(temp_ft * target_ft, out=temp, s=fast_shape)
902
703
 
903
704
  be.subtract(
904
705
  numerator,
@@ -908,21 +709,21 @@ def mcc_scoring(
908
709
 
909
710
  # temp_3 = fixed_denom
910
711
  be.multiply(temp_ft, target_ft2, out=temp_ft)
911
- temp3 = irfftn(temp_ft, temp3)
712
+ temp3 = be.irfftn(temp_ft, out=temp3, s=fast_shape)
912
713
  be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
913
714
  be.maximum(temp3, 0.0, out=temp3)
914
715
 
915
716
  # temp = moving_denom
916
- temp_ft = rfftn(be.square(template_rot), temp_ft)
717
+ temp_ft = be.rfftn(be.square(template_rot), out=temp_ft, s=fast_shape)
917
718
  be.multiply(target_mask_ft, temp_ft, out=temp_ft)
918
- temp = irfftn(temp_ft, temp)
719
+ temp = be.irfftn(temp_ft, out=temp, s=fast_shape)
919
720
 
920
721
  be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
921
722
  be.maximum(temp, 0.0, out=temp)
922
723
 
923
724
  # temp_2 = denom
924
725
  be.multiply(temp3, temp, out=temp)
925
- be.sqrt(temp, temp2)
726
+ be.sqrt(temp, out=temp2)
926
727
 
927
728
  # Pixels where `denom` is very small will introduce large
928
729
  # numbers after division. To get around this problem,
@@ -938,29 +739,11 @@ def mcc_scoring(
938
739
  mask_overlap, axis=axes, keepdims=True
939
740
  )
940
741
  temp[mask_overlap < number_px_threshold] = 0.0
941
- callback_func(temp, rotation_matrix=rotation)
742
+ callback(temp, rotation_matrix=rotation)
942
743
 
943
744
  return callback
944
745
 
945
746
 
946
- def _format_slice(shape, squeeze_axis):
947
- ret = tuple(
948
- slice(None) if i not in squeeze_axis else 0 for i, _ in enumerate(shape)
949
- )
950
- return ret
951
-
952
-
953
- def _get_batch_dim(target, template):
954
- target_batch, template_batch = [], []
955
- for i in range(len(target.shape)):
956
- if target.shape[i] == 1 and template.shape[i] != 1:
957
- template_batch.append(i)
958
- if target.shape[i] != 1 and template.shape[i] == 1:
959
- target_batch.append(i)
960
-
961
- return target_batch, template_batch
962
-
963
-
964
747
  def flc_scoring2(
965
748
  template: shm_type,
966
749
  template_mask: shm_type,
@@ -973,25 +756,22 @@ def flc_scoring2(
973
756
  callback: CallbackClass,
974
757
  interpolation_order: int,
975
758
  ) -> CallbackClass:
976
- callback_func = conditional_execute(callback, callback is not None)
977
-
978
- # Retrieve objects from shared memory
979
759
  template = be.from_sharedarr(template)
980
760
  template_mask = be.from_sharedarr(template_mask)
981
761
  ft_target = be.from_sharedarr(ft_target)
982
762
  ft_target2 = be.from_sharedarr(ft_target2)
983
763
  template_filter = be.from_sharedarr(template_filter)
984
764
 
985
- data_axes = None
986
- target_batch, template_batch = _get_batch_dim(ft_target, template)
987
- sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
988
- sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
765
+ tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
989
766
 
990
- data_shape = fast_shape
991
- if len(target_batch) or len(template_batch):
992
- batch = (*target_batch, *template_batch)
993
- data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
994
- data_shape = tuple(fast_shape[i] for i in data_axes)
767
+ nd = len(fast_shape)
768
+ sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
769
+ tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
770
+
771
+ axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
772
+ if len(tar_batch) or len(tmpl_batch):
773
+ axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
774
+ shape = tuple(fast_shape[i] for i in axes)
995
775
 
996
776
  arr = be.zeros(fast_shape, be._float_dtype)
997
777
  temp = be.zeros(fast_shape, be._float_dtype)
@@ -999,34 +779,11 @@ def flc_scoring2(
999
779
  ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)
1000
780
 
1001
781
  tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
1002
- if be.size(template_filter) != 1:
1003
- ret_shape = np.broadcast_shapes(
1004
- sqz_cmpl, tuple(int(x) for x in template_filter.shape)
1005
- )
1006
- ft_temp = be.zeros(ret_shape, be._complex_dtype)
1007
-
1008
- _fftargs = {
1009
- "real_dtype": be._float_dtype,
1010
- "cmpl_dtype": be._complex_dtype,
1011
- "inv_output_shape": fast_shape,
1012
- "fwd_axes": data_axes,
1013
- "inv_axes": data_axes,
1014
- "inv_shape": fast_ft_shape,
1015
- "temp_fwd": arr_sqz if _shape_match(ft_temp.shape, sqz_cmpl) else arr,
1016
- }
1017
-
1018
- # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
1019
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
1020
- rfftn, irfftn = be.build_fft(temp_inv=ft_denom, **_fftargs)
1021
- _ = _fftargs.pop("temp_fwd", None)
1022
782
 
1023
783
  template_filter_func = _create_filter_func(
1024
784
  arr_shape=template.shape,
1025
- arr_ft_shape=sqz_cmpl,
1026
- arr_filter=template_filter,
1027
- rfftn=rfftn,
1028
- irfftn=irfftn,
1029
- **_fftargs,
785
+ filter_shape=template_filter.shape,
786
+ arr_padded=True,
1030
787
  )
1031
788
 
1032
789
  eps = be.eps(be._float_dtype)
@@ -1034,33 +791,32 @@ def flc_scoring2(
1034
791
  rotation = rotations[index]
1035
792
  be.fill(arr, 0)
1036
793
  be.fill(temp, 0)
1037
- arr_sqz, tmp_sqz = be.rigid_transform(
1038
- arr=template,
1039
- arr_mask=template_mask,
794
+
795
+ _, _ = be.rigid_transform(
796
+ arr=template[tmpl_subset],
797
+ arr_mask=template_mask[tmpl_subset],
1040
798
  rotation_matrix=rotation,
1041
- out=arr_sqz,
1042
- out_mask=tmp_sqz,
799
+ out=arr_sqz[tmpl_subset],
800
+ out_mask=tmp_sqz[tmpl_subset],
1043
801
  use_geometric_center=True,
1044
802
  order=interpolation_order,
1045
- cache=True,
1046
- batched=True,
803
+ cache=False,
804
+ batched=batched,
1047
805
  )
1048
- n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
806
+
807
+ n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
1049
808
  arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
1050
- arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=data_axes)
809
+ arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=axes)
1051
810
 
1052
- ft_temp = be.rfftn(tmp_sqz, ft_temp, axes=data_axes)
1053
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
1054
- temp = be.irfftn(ft_denom, temp, axes=data_axes, s=data_shape)
1055
- ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
1056
- temp2 = be.irfftn(ft_denom, temp2, axes=data_axes, s=data_shape)
811
+ ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
812
+ temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
813
+ temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp2, shape, axes)
1057
814
 
1058
- ft_temp = rfftn(arr_norm, ft_denom)
1059
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
1060
- arr = irfftn(ft_denom, arr)
815
+ ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
816
+ arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
1061
817
 
1062
- be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
1063
- callback_func(arr, rotation_matrix=rotation)
818
+ arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
819
+ callback(arr, rotation_matrix=rotation)
1064
820
 
1065
821
  return callback
1066
822
 
@@ -1085,101 +841,140 @@ def corr_scoring2(
1085
841
  numerator = be.from_sharedarr(numerator)
1086
842
  template_filter = be.from_sharedarr(template_filter)
1087
843
 
1088
- data_axes = None
1089
- target_batch, template_batch = _get_batch_dim(ft_target, template)
1090
- sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
1091
- sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
1092
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
1093
- if len(target_batch) or len(template_batch):
1094
- batch = (*target_batch, *template_batch)
1095
- data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
1096
- unpadded_slice = tuple(
1097
- slice(None) if i in batch else slice(0, x)
1098
- for i, x in enumerate(template.shape)
1099
- )
844
+ tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
845
+
846
+ nd = len(fast_shape)
847
+ sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
848
+ tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
849
+
850
+ axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
851
+ if len(tar_batch) or len(tmpl_batch):
852
+ axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
853
+ shape = tuple(fast_shape[i] for i in axes)
854
+
855
+ unpadded_slice = tuple(
856
+ slice(None) if i in (*tar_batch, *tmpl_batch) else slice(0, x)
857
+ for i, x in enumerate(template.shape)
858
+ )
1100
859
 
1101
860
  arr = be.zeros(fast_shape, be._float_dtype)
1102
861
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
1103
862
  arr_sqz, ft_sqz = arr[sqz_slice], ft_temp[sqz_slice]
1104
863
 
1105
- if be.size(template_filter) != 1:
1106
- # The filter could be w.r.t the unpadded template
1107
- ret_shape = tuple(
1108
- int(x * y) if x == 1 or y == 1 else y
1109
- for x, y in zip(sqz_cmpl, template_filter.shape)
1110
- )
1111
- ft_sqz = be.zeros(ret_shape, be._complex_dtype)
1112
-
1113
- norm_func, norm_template, mask_sum = normalize_template, False, 1
864
+ n_obs = None
1114
865
  if template_mask is not None:
1115
866
  template_mask = be.from_sharedarr(template_mask)
1116
- norm_template, mask_sum = True, be.sum(
1117
- be.astype(template_mask, be._overflow_safe_dtype),
1118
- axis=data_axes,
1119
- keepdims=True,
1120
- )
1121
- if be.datatype_bytes(template_mask.dtype) == 2:
1122
- norm_func = _normalize_template_overflow_safe
1123
-
1124
- callback_func = conditional_execute(callback, callback is not None)
1125
- norm_template = conditional_execute(norm_func, norm_template)
1126
- norm_numerator = conditional_execute(
1127
- be.subtract, identity, _shape_match(numerator.shape, fast_shape)
1128
- )
1129
- norm_denominator = conditional_execute(
1130
- be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
1131
- )
867
+ n_obs = be.sum(template_mask, axis=axes, keepdims=True)
1132
868
 
1133
- _fftargs = {
1134
- "real_dtype": be._float_dtype,
1135
- "cmpl_dtype": be._complex_dtype,
1136
- "fwd_axes": data_axes,
1137
- "inv_axes": data_axes,
1138
- "inv_shape": fast_ft_shape,
1139
- "inv_output_shape": fast_shape,
1140
- "temp_fwd": arr_sqz if _shape_match(ft_sqz.shape, sqz_cmpl) else arr,
1141
- }
1142
-
1143
- # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
1144
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
1145
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
1146
- _ = _fftargs.pop("temp_fwd", None)
869
+ norm_template = conditional_execute(normalize_template, n_obs is not None)
870
+ norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
871
+ norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
1147
872
 
1148
873
  template_filter_func = _create_filter_func(
1149
874
  arr_shape=template.shape,
1150
- arr_ft_shape=sqz_cmpl,
1151
- arr_filter=template_filter,
1152
- rfftn=rfftn,
1153
- irfftn=irfftn,
1154
- **_fftargs,
875
+ filter_shape=template_filter.shape,
876
+ arr_padded=True,
1155
877
  )
1156
878
 
1157
879
  for index in range(rotations.shape[0]):
1158
880
  be.fill(arr, 0)
1159
881
  rotation = rotations[index]
1160
- arr_sqz, _ = be.rigid_transform(
1161
- arr=template,
882
+ _, _ = be.rigid_transform(
883
+ arr=template[tmpl_subset],
1162
884
  rotation_matrix=rotation,
1163
- out=arr_sqz,
885
+ out=arr_sqz[tmpl_subset],
1164
886
  use_geometric_center=True,
1165
887
  order=interpolation_order,
1166
- cache=True,
1167
- batched=True,
888
+ cache=False,
889
+ batched=batched,
1168
890
  )
1169
891
  arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
1170
- norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)
892
+ norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)
1171
893
 
1172
- ft_sqz = rfftn(arr_norm, ft_sqz)
1173
- ft_temp = be.multiply(ft_target, ft_sqz, out=ft_temp)
1174
- arr = irfftn(ft_temp, arr)
894
+ ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
895
+ arr = _correlate_fts(ft_target, ft_sqz, ft_temp, arr, shape, axes)
1175
896
 
1176
- arr = norm_numerator(arr, numerator, out=arr)
1177
- arr = norm_denominator(arr, inv_denominator, out=arr)
1178
- callback_func(arr, rotation_matrix=rotation)
897
+ arr = norm_sub(arr, numerator, out=arr)
898
+ arr = norm_mul(arr, inv_denominator, out=arr)
899
+ callback(arr, rotation_matrix=rotation)
1179
900
 
1180
901
  return callback
1181
902
 
1182
903
 
904
def _get_batch_dim(target, template):
    """
    Identify the batch axes of target and template from their shapes.

    An axis counts as a template batch axis when the target has size one
    along it and the template does not, and as a target batch axis in the
    opposite situation. Axes where both or neither have size one are not
    reported.

    Parameters
    ----------
    target : object with a ``shape`` attribute
        Array whose dimensionality determines the axes that are inspected.
    template : object with a ``shape`` attribute
        Array compared against ``target`` axis by axis. It needs at least
        as many dimensions as ``target``.

    Returns
    -------
    tuple of list of int
        Indices of the target batch axes and of the template batch axes.

    Raises
    ------
    IndexError
        If ``template`` has fewer dimensions than ``target``.
    """
    target_batch, template_batch = [], []
    template_shape = template.shape
    for axis, target_size in enumerate(target.shape):
        template_size = template_shape[axis]
        # The two conditions exclude each other, so an axis is assigned to
        # at most one of the two lists.
        if target_size == 1 and template_size != 1:
            template_batch.append(axis)
        elif target_size != 1 and template_size == 1:
            target_batch.append(axis)

    return target_batch, template_batch
913
+
914
+
915
def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=None):
    """
    Multiply two Fourier-space arrays and transform the product to real space.

    Parameters
    ----------
    ft_tar : BackendArray
        First Fourier-space operand.
    ft_tmpl : BackendArray
        Second Fourier-space operand.
    ft_buffer : BackendArray
        Buffer that receives the element-wise product.
    real_buffer : BackendArray
        Buffer that receives the inverse real transform of the product.
    fast_shape : tuple of int
        Forwarded as ``s`` to the inverse transform.
    axes : tuple of int, optional
        Forwarded as ``axes`` to the inverse transform.

    Returns
    -------
    BackendArray
        Return value of the backend's inverse real transform.
    """
    # NOTE(review): no conjugation happens here; presumably the caller
    # supplies operands that already encode a correlation - confirm.
    product = be.multiply(ft_tar, ft_tmpl, out=ft_buffer)
    real_space = be.irfftn(product, out=real_buffer, s=fast_shape, axes=axes)
    return real_space
918
+
919
+
920
def _create_filter_func(
    arr_shape: Tuple[int],
    filter_shape: Tuple[int],
    arr_padded: bool = False,
    axes=None,
) -> Callable:
    """
    Configure template filtering function for Fourier transforms.

    Three cases are distinguished. If ``filter_shape`` is ``(1,)`` no
    filtering is configured and an identity operation is returned. If the
    array is not padded, it is transformed, multiplied with the filter and
    transformed back in place. If the array is padded, only its leading
    ``arr_shape`` region is filtered via internal scratch buffers and then
    written back into the padded array.

    Parameters
    ----------
    arr_shape : tuple of ints
        Shape of the array to be filtered.
    filter_shape : tuple of ints
        Shape of the Fourier-space filter. ``(1,)`` disables filtering.
    arr_padded : bool, optional
        Whether the input template is padded and will need to be cropped
        to arr_shape prior to filter applications. Defaults to False.
    axes : tuple of ints, optional
        Axes to perform Fourier transform over. By default all axes are
        transformed.

    Returns
    -------
    Callable
        Filter function with parameters template, ft_temp and template_filter.
    """
    if filter_shape == (1,):
        return conditional_execute(identity, execute_operation=True)

    def _transform_shape(shape):
        # The backend expects one entry in ``s`` per transformed axis
        if axes is None:
            return shape
        return tuple(shape[i] for i in axes)

    # Default case, all shapes are correctly matched
    def _apply_filter(template, ft_temp, template_filter):
        s = _transform_shape(template.shape)
        ft_temp = be.rfftn(template, out=ft_temp, s=s, axes=axes)
        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
        return be.irfftn(ft_temp, out=template, s=s, axes=axes)

    if not arr_padded:
        return _apply_filter

    # Array is padded but the filter is w.r.t. the original template
    real_subset = tuple(slice(0, x) for x in arr_shape)
    # Scratch buffers are allocated once and reused on every call
    _template = be.zeros(arr_shape, be._float_dtype)
    _ft_temp = be.zeros(filter_shape, be._complex_dtype)

    def _apply_filter_subset(template, ft_temp, template_filter):
        # ft_temp is accepted for signature compatibility with _apply_filter;
        # the internally allocated _ft_temp is used instead.
        _template[:] = template[real_subset]
        template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
        return template

    return _apply_filter_subset
976
+
977
+
1183
978
  MATCHING_EXHAUSTIVE_REGISTER = {
1184
979
  "CC": (cc_setup, corr_scoring),
1185
980
  "LCC": (lcc_setup, corr_scoring),
@@ -1188,6 +983,6 @@ MATCHING_EXHAUSTIVE_REGISTER = {
1188
983
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1189
984
  "FLC": (flc_setup, flc_scoring),
1190
985
  "MCC": (mcc_setup, mcc_scoring),
1191
- "batchFLCSpherical": (flcSphericalMask_setup, corr_scoring2),
986
+ "batchFLCSphericalMask": (flcSphericalMask_setup, corr_scoring2),
1192
987
  "batchFLC": (flc_setup, flc_scoring2),
1193
988
  }