pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/matching_scores.py CHANGED
@@ -18,138 +18,9 @@ from .matching_utils import (
18
18
  conditional_execute,
19
19
  identity,
20
20
  normalize_template,
21
- _normalize_template_overflow_safe,
22
21
  )
23
22
 
24
23
 
25
- def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
26
- """
27
- Determine whether ``shape1`` is equal to ``shape2``.
28
-
29
- Parameters
30
- ----------
31
- shape1, shape2 : tuple of ints
32
- Shapes to compare.
33
-
34
- Returns
35
- -------
36
- Bool
37
- ``shape1`` is equal to ``shape2``.
38
- """
39
- if len(shape1) != len(shape2):
40
- return False
41
- return shape1 == shape2
42
-
43
-
44
- def _create_filter_func(
45
- fwd_shape: Tuple[int],
46
- inv_shape: Tuple[int],
47
- arr_shape: Tuple[int],
48
- arr_filter: BackendArray,
49
- arr_ft_shape: Tuple[int],
50
- inv_output_shape: Tuple[int],
51
- real_dtype: type,
52
- cmpl_dtype: type,
53
- fwd_axes=None,
54
- inv_axes=None,
55
- rfftn: Callable = None,
56
- irfftn: Callable = None,
57
- ) -> Callable:
58
- """
59
- Configure template filtering function for Fourier transforms.
60
-
61
- Conceptually we distinguish between three cases. The base case
62
- is that both template and the corresponding filter have the same
63
- shape. Padding is used when the template filter is larger than
64
- the template, for instance to better resolve Fourier filters. Finally
65
- this function also handles the case when a filter is supposed to be
66
- broadcasted over the template batch dimension.
67
-
68
- Parameters
69
- ----------
70
- fwd_shape : tuple of ints
71
- Input shape of rfftn.
72
- inv_shape : tuple of ints
73
- Input shape of irfftn.
74
- arr_shape : tuple of ints
75
- Shape of the array to be filtered.
76
- arr_ft_shape : tuple of ints
77
- Shape of the Fourier transform of the array.
78
- arr_filter : BackendArray
79
- Precomputed filter to apply in the frequency domain.
80
- rfftn : Callable, optional
81
- Foward Fourier transform.
82
- irfftn : Callable, optional
83
- Inverse Fourier transform.
84
-
85
- Returns
86
- -------
87
- Callable
88
- Filter function with parameters template, ft_temp and template_filter.
89
- """
90
- if be.size(arr_filter) == 1:
91
- return conditional_execute(identity, identity, False)
92
-
93
- filter_shape = tuple(int(x) for x in arr_filter.shape)
94
- try:
95
- product_ft_shape = np.broadcast_shapes(arr_ft_shape, filter_shape)
96
- except ValueError:
97
- product_ft_shape, inv_output_shape = filter_shape, arr_shape
98
-
99
- rfft_valid = _shape_match(arr_shape, fwd_shape)
100
- rfft_valid = rfft_valid and _shape_match(product_ft_shape, inv_shape)
101
- rfft_valid = rfft_valid and rfftn is not None and irfftn is not None
102
-
103
- # FTTs were not or built for the wrong shape
104
- if not rfft_valid:
105
- _fwd_shape = arr_shape
106
- if all(x > y for x, y in zip(arr_shape, product_ft_shape)):
107
- _fwd_shape = fwd_shape
108
-
109
- rfftn, irfftn = be.build_fft(
110
- fwd_shape=_fwd_shape,
111
- inv_shape=product_ft_shape,
112
- real_dtype=real_dtype,
113
- cmpl_dtype=cmpl_dtype,
114
- inv_output_shape=inv_output_shape,
115
- fwd_axes=fwd_axes,
116
- inv_axes=inv_axes,
117
- )
118
-
119
- # Default case, all shapes are correctly matched
120
- def _apply_filter(template, ft_temp, template_filter):
121
- ft_temp = rfftn(template, ft_temp)
122
- ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
123
- return irfftn(ft_temp, template)
124
-
125
- if not _shape_match(arr_ft_shape, filter_shape):
126
- real_subset = tuple(slice(0, x) for x in arr_shape)
127
- _template = be.zeros(arr_shape, be._float_dtype)
128
- _ft_temp = be.zeros(product_ft_shape, be._complex_dtype)
129
-
130
- # Arr is padded, filter is not
131
- def _apply_filter_subset(template, ft_temp, template_filter):
132
- # TODO: Benchmark this
133
- _template[:] = template[real_subset]
134
- template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
135
- return template
136
-
137
- # Filter application requires a broadcasting operation
138
- def _apply_filter_broadcast(template, ft_temp, template_filter):
139
- _ft_prod = rfftn(template, _ft_temp2)
140
- _ft_res = be.multiply(_ft_prod, template_filter, out=_ft_temp)
141
- return irfftn(_ft_res, _template)
142
-
143
- if any(x > y and y == 1 for x, y in zip(filter_shape, arr_ft_shape)):
144
- _template = be.zeros(inv_output_shape, be._float_dtype)
145
- _ft_temp2 = be.zeros((1, *product_ft_shape[1:]), be._complex_dtype)
146
- return _apply_filter_broadcast
147
-
148
- return _apply_filter_subset
149
-
150
- return _apply_filter
151
-
152
-
153
24
  def cc_setup(
154
25
  matching_data: type,
155
26
  fast_shape: Tuple[int],
@@ -176,8 +47,6 @@ def cc_setup(
176
47
  axes = matching_data._batch_axis(matching_data._batch_mask)
177
48
 
178
49
  ret = {
179
- "fast_shape": fast_shape,
180
- "fast_ft_shape": fast_ft_shape,
181
50
  "template": be.to_sharedarr(matching_data.template, shm_handler),
182
51
  "ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
183
52
  "inv_denominator": be.to_sharedarr(
@@ -219,8 +88,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
219
88
  for subset in subsets:
220
89
  template[subset] = laplace(template[subset], mode="wrap")
221
90
 
222
- matching_data._target = target
223
- matching_data._template = template
91
+ matching_data._target = be.to_backend_array(target)
92
+ matching_data._template = be.to_backend_array(template)
224
93
 
225
94
  return cc_setup(matching_data=matching_data, **kwargs)
226
95
 
@@ -316,8 +185,6 @@ def corr_setup(
316
185
  denominator = be.multiply(denominator, mask, out=denominator)
317
186
 
318
187
  ret = {
319
- "fast_shape": fast_shape,
320
- "fast_ft_shape": fast_ft_shape,
321
188
  "template": be.to_sharedarr(template, shm_handler),
322
189
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
323
190
  "inv_denominator": be.to_sharedarr(denominator, shm_handler),
@@ -376,8 +243,6 @@ def flc_setup(
376
243
  ft_target2 = be.rfftn(target_pad, axes=data_axes)
377
244
 
378
245
  ret = {
379
- "fast_shape": fast_shape,
380
- "fast_ft_shape": fast_ft_shape,
381
246
  "template": be.to_sharedarr(matching_data.template, shm_handler),
382
247
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
383
248
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -438,8 +303,6 @@ def flcSphericalMask_setup(
438
303
 
439
304
  temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
440
305
  ret = {
441
- "fast_shape": fast_shape,
442
- "fast_ft_shape": fast_ft_shape,
443
306
  "template": be.to_sharedarr(matching_data.template, shm_handler),
444
307
  "template_mask": be.to_sharedarr(template_mask, shm_handler),
445
308
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -461,21 +324,14 @@ def mcc_setup(
461
324
  Setup function for :py:meth:`mcc_scoring`.
462
325
  """
463
326
  target, target_mask = matching_data.target, matching_data.target_mask
464
- target = be.multiply(target, target_mask > 0, out=target)
327
+ target = be.multiply(target, target_mask, out=target)
465
328
 
466
- target = be.topleft_pad(
467
- target,
468
- matching_data._batch_shape(fast_shape, matching_data._template_batch),
469
- )
470
- target_mask = be.topleft_pad(
471
- target_mask,
472
- matching_data._batch_shape(fast_shape, matching_data._template_batch),
473
- )
474
329
  ax = matching_data._batch_axis(matching_data._batch_mask)
330
+ shape = matching_data._batch_shape(fast_shape, matching_data._template_batch)
331
+ target = be.topleft_pad(target, shape)
332
+ target_mask = be.topleft_pad(target_mask, shape)
475
333
 
476
334
  ret = {
477
- "fast_shape": fast_shape,
478
- "fast_ft_shape": fast_ft_shape,
479
335
  "template": be.to_sharedarr(matching_data.template, shm_handler),
480
336
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
481
337
  "ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
@@ -541,8 +397,7 @@ def corr_scoring(
541
397
 
542
398
  Returns
543
399
  -------
544
- Optional[CallbackClass]
545
- ``callback`` if provided otherwise None.
400
+ CallbackClass
546
401
  """
547
402
  template = be.from_sharedarr(template)
548
403
  ft_target = be.from_sharedarr(ft_target)
@@ -550,71 +405,49 @@ def corr_scoring(
550
405
  numerator = be.from_sharedarr(numerator)
551
406
  template_filter = be.from_sharedarr(template_filter)
552
407
 
553
- norm_func, norm_template, mask_sum = normalize_template, False, 1
408
+ n_obs = None
554
409
  if template_mask is not None:
555
410
  template_mask = be.from_sharedarr(template_mask)
556
- norm_template, mask_sum = True, be.sum(template_mask)
557
- if be.datatype_bytes(template_mask.dtype) == 2:
558
- norm_func = _normalize_template_overflow_safe
559
- mask_sum = be.sum(be.astype(template_mask, be._overflow_safe_dtype))
560
-
561
- callback_func = conditional_execute(callback, callback is not None)
562
- norm_template = conditional_execute(norm_func, norm_template)
563
- norm_numerator = conditional_execute(
564
- be.subtract, identity, _shape_match(numerator.shape, fast_shape)
565
- )
566
- norm_denominator = conditional_execute(
567
- be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
568
- )
411
+ n_obs = be.sum(template_mask) if template_mask is not None else None
412
+
413
+ norm_template = conditional_execute(normalize_template, n_obs is not None)
414
+ norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
415
+ norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
569
416
 
570
417
  arr = be.zeros(fast_shape, be._float_dtype)
571
418
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
572
-
573
- _fftargs = {
574
- "real_dtype": be._float_dtype,
575
- "cmpl_dtype": be._complex_dtype,
576
- "inv_output_shape": fast_shape,
577
- "fwd_axes": None,
578
- "inv_axes": None,
579
- "inv_shape": fast_ft_shape,
580
- "temp_fwd": arr,
581
- }
582
-
583
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
584
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
585
- _ = _fftargs.pop("temp_fwd", None)
419
+ template_rot = be.zeros(template.shape, be._float_dtype)
586
420
 
587
421
  template_filter_func = _create_filter_func(
588
422
  arr_shape=template.shape,
589
- arr_ft_shape=fast_ft_shape,
590
- arr_filter=template_filter,
591
- rfftn=rfftn,
592
- irfftn=irfftn,
593
- **_fftargs,
423
+ filter_shape=template_filter.shape,
594
424
  )
595
425
 
426
+ center = be.divide(be.to_backend_array(template.shape) - 1, 2)
596
427
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
597
428
  for index in range(rotations.shape[0]):
598
429
  rotation = rotations[index]
599
- arr = be.fill(arr, 0)
600
- arr, _ = be.rigid_transform(
430
+ matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
431
+ _ = be.rigid_transform(
601
432
  arr=template,
602
- rotation_matrix=rotation,
603
- out=arr,
604
- use_geometric_center=True,
433
+ rotation_matrix=matrix,
434
+ out=template_rot,
605
435
  order=interpolation_order,
606
- cache=False,
436
+ cache=True,
607
437
  )
608
- arr = template_filter_func(arr, ft_temp, template_filter)
609
- norm_template(arr[unpadded_slice], template_mask, mask_sum)
610
438
 
611
- ft_temp = rfftn(arr, ft_temp)
612
- ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
613
- arr = irfftn(ft_temp, arr)
439
+ template_rot = template_filter_func(template_rot, ft_temp, template_filter)
440
+ norm_template(template_rot, template_mask, n_obs)
441
+
442
+ arr = be.fill(arr, 0)
443
+ arr[unpadded_slice] = template_rot
614
444
 
615
- arr = norm_numerator(arr, numerator, out=arr)
616
- arr = norm_denominator(arr, inv_denominator, out=arr)
617
- callback_func(arr, rotation_matrix=rotation)
445
+ ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
446
+ arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
447
+
448
+ arr = norm_sub(arr, numerator, out=arr)
449
+ arr = norm_mul(arr, inv_denominator, out=arr)
450
+ callback(arr, rotation_matrix=rotation)
618
451
 
619
452
  return callback
620
453
 
@@ -675,6 +508,10 @@ def flc_scoring(
675
508
  interpolation_order : int
676
509
  Spline order for template rotations.
677
510
 
511
+ Returns
512
+ -------
513
+ CallbackClass
514
+
678
515
  References
679
516
  ----------
680
517
  .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
@@ -691,63 +528,46 @@ def flc_scoring(
691
528
  temp2 = be.zeros(fast_shape, float_dtype)
692
529
  ft_temp = be.zeros(fast_ft_shape, complex_dtype)
693
530
  ft_denom = be.zeros(fast_ft_shape, complex_dtype)
531
+ template_rot = be.zeros(template.shape, be._float_dtype)
532
+ template_mask_rot = be.zeros(template.shape, be._float_dtype)
694
533
 
695
- _fftargs = {
696
- "real_dtype": be._float_dtype,
697
- "cmpl_dtype": be._complex_dtype,
698
- "inv_output_shape": fast_shape,
699
- "fwd_axes": None,
700
- "inv_axes": None,
701
- "inv_shape": fast_ft_shape,
702
- "temp_fwd": arr,
703
- }
704
-
705
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
706
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
707
- _ = _fftargs.pop("temp_fwd", None)
708
-
709
- template_filter_func = _create_filter_func(
710
- arr_shape=template.shape,
711
- arr_ft_shape=fast_ft_shape,
712
- arr_filter=template_filter,
713
- rfftn=rfftn,
714
- irfftn=irfftn,
715
- **_fftargs,
716
- )
534
+ tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
717
535
 
718
536
  eps = be.eps(float_dtype)
719
- callback_func = conditional_execute(callback, callback is not None)
537
+ center = be.divide(be.to_backend_array(template.shape) - 1, 2)
538
+ unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
720
539
  for index in range(rotations.shape[0]):
721
540
  rotation = rotations[index]
722
- arr = be.fill(arr, 0)
723
- temp = be.fill(temp, 0)
724
- arr, temp = be.rigid_transform(
541
+ matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
542
+ _ = be.rigid_transform(
725
543
  arr=template,
726
544
  arr_mask=template_mask,
727
- rotation_matrix=rotation,
728
- out=arr,
729
- out_mask=temp,
545
+ rotation_matrix=matrix,
546
+ out=template_rot,
547
+ out_mask=template_mask_rot,
730
548
  use_geometric_center=True,
731
549
  order=interpolation_order,
732
- cache=False,
550
+ cache=True,
733
551
  )
734
552
 
735
- n_obs = be.sum(temp)
736
- arr = template_filter_func(arr, ft_temp, template_filter)
737
- arr = normalize_template(arr, temp, n_obs, axis=None)
553
+ n_obs = be.sum(template_mask_rot)
554
+ template_rot = tmpl_filter_func(template_rot, ft_temp, template_filter)
555
+ template_rot = normalize_template(template_rot, template_mask_rot, n_obs)
556
+
557
+ arr = be.fill(arr, 0)
558
+ temp = be.fill(temp, 0)
559
+ arr[unpadded_slice] = template_rot
560
+ temp[unpadded_slice] = template_mask_rot
738
561
 
739
- ft_temp = rfftn(temp, ft_temp)
740
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
741
- temp = irfftn(ft_denom, temp)
742
- ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
743
- temp2 = irfftn(ft_denom, temp2)
562
+ ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
563
+ temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
564
+ temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp, fast_shape)
744
565
 
745
- ft_temp = rfftn(arr, ft_temp)
746
- ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
747
- arr = irfftn(ft_temp, arr)
566
+ ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
567
+ arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
748
568
 
749
569
  arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
750
- callback_func(arr, rotation_matrix=rotation)
570
+ callback(arr, rotation_matrix=rotation)
751
571
 
752
572
  return callback
753
573
 
@@ -814,6 +634,10 @@ def mcc_scoring(
814
634
  overlap_ratio : float, optional
815
635
  Required fractional mask overlap, 0.3 by default.
816
636
 
637
+ Returns
638
+ -------
639
+ CallbackClass
640
+
817
641
  References
818
642
  ----------
819
643
  .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
@@ -839,30 +663,12 @@ def mcc_scoring(
839
663
  temp3 = be.zeros(fast_shape, float_dtype)
840
664
  temp_ft = be.zeros(fast_ft_shape, complex_dtype)
841
665
 
842
- _fftargs = {
843
- "real_dtype": be._float_dtype,
844
- "cmpl_dtype": be._complex_dtype,
845
- "inv_output_shape": fast_shape,
846
- "fwd_axes": None,
847
- "inv_axes": None,
848
- "inv_shape": fast_ft_shape,
849
- "temp_fwd": temp,
850
- }
851
-
852
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
853
- rfftn, irfftn = be.build_fft(temp_inv=temp_ft, **_fftargs)
854
- _ = _fftargs.pop("temp_fwd", None)
855
-
856
666
  template_filter_func = _create_filter_func(
857
667
  arr_shape=template.shape,
858
- arr_ft_shape=fast_ft_shape,
859
- arr_filter=template_filter,
860
- rfftn=rfftn,
861
- irfftn=irfftn,
862
- **_fftargs,
668
+ filter_shape=template_filter.shape,
669
+ arr_padded=True,
863
670
  )
864
671
 
865
- callback_func = conditional_execute(callback, callback is not None)
866
672
  for index in range(rotations.shape[0]):
867
673
  rotation = rotations[index]
868
674
  template_rot = be.fill(template_rot, 0)
@@ -875,23 +681,25 @@ def mcc_scoring(
875
681
  out_mask=temp,
876
682
  use_geometric_center=True,
877
683
  order=interpolation_order,
878
- cache=False,
684
+ cache=True,
879
685
  )
880
686
 
881
687
  template_filter_func(template_rot, temp_ft, template_filter)
882
688
  normalize_template(template_rot, temp, be.sum(temp))
883
689
 
884
- temp_ft = rfftn(template_rot, temp_ft)
885
- temp2 = irfftn(target_mask_ft * temp_ft, temp2)
886
- numerator = irfftn(target_ft * temp_ft, numerator)
690
+ temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
691
+ temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
692
+ numerator = be.irfftn(target_ft * temp_ft, out=numerator, s=fast_shape)
887
693
 
888
694
  # temp template_mask_rot | temp_ft template_mask_rot_ft
889
695
  # Calculate overlap of masks at every point in the convolution.
890
696
  # Locations with high overlap should not be taken into account.
891
- temp_ft = rfftn(temp, temp_ft)
892
- mask_overlap = irfftn(temp_ft * target_mask_ft, mask_overlap)
697
+ temp_ft = be.rfftn(temp, out=temp_ft, s=fast_shape)
698
+ mask_overlap = be.irfftn(
699
+ temp_ft * target_mask_ft, out=mask_overlap, s=fast_shape
700
+ )
893
701
  be.maximum(mask_overlap, eps, out=mask_overlap)
894
- temp = irfftn(temp_ft * target_ft, temp)
702
+ temp = be.irfftn(temp_ft * target_ft, out=temp, s=fast_shape)
895
703
 
896
704
  be.subtract(
897
705
  numerator,
@@ -901,21 +709,21 @@ def mcc_scoring(
901
709
 
902
710
  # temp_3 = fixed_denom
903
711
  be.multiply(temp_ft, target_ft2, out=temp_ft)
904
- temp3 = irfftn(temp_ft, temp3)
712
+ temp3 = be.irfftn(temp_ft, out=temp3, s=fast_shape)
905
713
  be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
906
714
  be.maximum(temp3, 0.0, out=temp3)
907
715
 
908
716
  # temp = moving_denom
909
- temp_ft = rfftn(be.square(template_rot), temp_ft)
717
+ temp_ft = be.rfftn(be.square(template_rot), out=temp_ft, s=fast_shape)
910
718
  be.multiply(target_mask_ft, temp_ft, out=temp_ft)
911
- temp = irfftn(temp_ft, temp)
719
+ temp = be.irfftn(temp_ft, out=temp, s=fast_shape)
912
720
 
913
721
  be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
914
722
  be.maximum(temp, 0.0, out=temp)
915
723
 
916
724
  # temp_2 = denom
917
725
  be.multiply(temp3, temp, out=temp)
918
- be.sqrt(temp, temp2)
726
+ be.sqrt(temp, out=temp2)
919
727
 
920
728
  # Pixels where `denom` is very small will introduce large
921
729
  # numbers after division. To get around this problem,
@@ -931,29 +739,11 @@ def mcc_scoring(
931
739
  mask_overlap, axis=axes, keepdims=True
932
740
  )
933
741
  temp[mask_overlap < number_px_threshold] = 0.0
934
- callback_func(temp, rotation_matrix=rotation)
742
+ callback(temp, rotation_matrix=rotation)
935
743
 
936
744
  return callback
937
745
 
938
746
 
939
- def _format_slice(shape, squeeze_axis):
940
- ret = tuple(
941
- slice(None) if i not in squeeze_axis else 0 for i, _ in enumerate(shape)
942
- )
943
- return ret
944
-
945
-
946
- def _get_batch_dim(target, template):
947
- target_batch, template_batch = [], []
948
- for i in range(len(target.shape)):
949
- if target.shape[i] == 1 and template.shape[i] != 1:
950
- template_batch.append(i)
951
- if target.shape[i] != 1 and template.shape[i] == 1:
952
- target_batch.append(i)
953
-
954
- return target_batch, template_batch
955
-
956
-
957
747
  def flc_scoring2(
958
748
  template: shm_type,
959
749
  template_mask: shm_type,
@@ -966,25 +756,22 @@ def flc_scoring2(
966
756
  callback: CallbackClass,
967
757
  interpolation_order: int,
968
758
  ) -> CallbackClass:
969
- callback_func = conditional_execute(callback, callback is not None)
970
-
971
- # Retrieve objects from shared memory
972
759
  template = be.from_sharedarr(template)
973
760
  template_mask = be.from_sharedarr(template_mask)
974
761
  ft_target = be.from_sharedarr(ft_target)
975
762
  ft_target2 = be.from_sharedarr(ft_target2)
976
763
  template_filter = be.from_sharedarr(template_filter)
977
764
 
978
- data_axes = None
979
- target_batch, template_batch = _get_batch_dim(ft_target, template)
980
- sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
981
- sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
765
+ tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
766
+
767
+ nd = len(fast_shape)
768
+ sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
769
+ tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
982
770
 
983
- data_shape = fast_shape
984
- if len(target_batch) or len(template_batch):
985
- batch = (*target_batch, *template_batch)
986
- data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
987
- data_shape = tuple(fast_shape[i] for i in data_axes)
771
+ axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
772
+ if len(tar_batch) or len(tmpl_batch):
773
+ axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
774
+ shape = tuple(fast_shape[i] for i in axes)
988
775
 
989
776
  arr = be.zeros(fast_shape, be._float_dtype)
990
777
  temp = be.zeros(fast_shape, be._float_dtype)
@@ -992,34 +779,11 @@ def flc_scoring2(
992
779
  ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)
993
780
 
994
781
  tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
995
- if be.size(template_filter) != 1:
996
- ret_shape = np.broadcast_shapes(
997
- sqz_cmpl, tuple(int(x) for x in template_filter.shape)
998
- )
999
- ft_temp = be.zeros(ret_shape, be._complex_dtype)
1000
-
1001
- _fftargs = {
1002
- "real_dtype": be._float_dtype,
1003
- "cmpl_dtype": be._complex_dtype,
1004
- "inv_output_shape": fast_shape,
1005
- "fwd_axes": data_axes,
1006
- "inv_axes": data_axes,
1007
- "inv_shape": fast_ft_shape,
1008
- "temp_fwd": arr_sqz if _shape_match(ft_temp.shape, sqz_cmpl) else arr,
1009
- }
1010
-
1011
- # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
1012
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
1013
- rfftn, irfftn = be.build_fft(temp_inv=ft_denom, **_fftargs)
1014
- _ = _fftargs.pop("temp_fwd", None)
1015
782
 
1016
783
  template_filter_func = _create_filter_func(
1017
784
  arr_shape=template.shape,
1018
- arr_ft_shape=sqz_cmpl,
1019
- arr_filter=template_filter,
1020
- rfftn=rfftn,
1021
- irfftn=irfftn,
1022
- **_fftargs,
785
+ filter_shape=template_filter.shape,
786
+ arr_padded=True,
1023
787
  )
1024
788
 
1025
789
  eps = be.eps(be._float_dtype)
@@ -1027,32 +791,32 @@ def flc_scoring2(
1027
791
  rotation = rotations[index]
1028
792
  be.fill(arr, 0)
1029
793
  be.fill(temp, 0)
1030
- arr_sqz, tmp_sqz = be.rigid_transform(
1031
- arr=template,
1032
- arr_mask=template_mask,
794
+
795
+ _, _ = be.rigid_transform(
796
+ arr=template[tmpl_subset],
797
+ arr_mask=template_mask[tmpl_subset],
1033
798
  rotation_matrix=rotation,
1034
- out=arr_sqz,
1035
- out_mask=tmp_sqz,
799
+ out=arr_sqz[tmpl_subset],
800
+ out_mask=tmp_sqz[tmpl_subset],
1036
801
  use_geometric_center=True,
1037
802
  order=interpolation_order,
1038
803
  cache=False,
804
+ batched=batched,
1039
805
  )
1040
- n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
806
+
807
+ n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
1041
808
  arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
1042
- arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=data_axes)
809
+ arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=axes)
1043
810
 
1044
- ft_temp = be.rfftn(tmp_sqz, ft_temp, axes=data_axes)
1045
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
1046
- temp = be.irfftn(ft_denom, temp, axes=data_axes, s=data_shape)
1047
- ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
1048
- temp2 = be.irfftn(ft_denom, temp2, axes=data_axes, s=data_shape)
811
+ ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
812
+ temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
813
+ temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp2, shape, axes)
1049
814
 
1050
- ft_temp = rfftn(arr_norm, ft_denom)
1051
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
1052
- arr = irfftn(ft_denom, arr)
815
+ ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
816
+ arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
1053
817
 
1054
- be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
1055
- callback_func(arr, rotation_matrix=rotation)
818
+ arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
819
+ callback(arr, rotation_matrix=rotation)
1056
820
 
1057
821
  return callback
1058
822
 
@@ -1077,100 +841,140 @@ def corr_scoring2(
1077
841
  numerator = be.from_sharedarr(numerator)
1078
842
  template_filter = be.from_sharedarr(template_filter)
1079
843
 
1080
- data_axes = None
1081
- target_batch, template_batch = _get_batch_dim(ft_target, template)
1082
- sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
1083
- sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
1084
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
1085
- if len(target_batch) or len(template_batch):
1086
- batch = (*target_batch, *template_batch)
1087
- data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
1088
- unpadded_slice = tuple(
1089
- slice(None) if i in batch else slice(0, x)
1090
- for i, x in enumerate(template.shape)
1091
- )
844
+ tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
845
+
846
+ nd = len(fast_shape)
847
+ sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
848
+ tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
849
+
850
+ axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
851
+ if len(tar_batch) or len(tmpl_batch):
852
+ axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
853
+ shape = tuple(fast_shape[i] for i in axes)
854
+
855
+ unpadded_slice = tuple(
856
+ slice(None) if i in (*tar_batch, *tmpl_batch) else slice(0, x)
857
+ for i, x in enumerate(template.shape)
858
+ )
1092
859
 
1093
860
  arr = be.zeros(fast_shape, be._float_dtype)
1094
861
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
1095
862
  arr_sqz, ft_sqz = arr[sqz_slice], ft_temp[sqz_slice]
1096
863
 
1097
- if be.size(template_filter) != 1:
1098
- # The filter could be w.r.t the unpadded template
1099
- ret_shape = tuple(
1100
- int(x * y) if x == 1 or y == 1 else y
1101
- for x, y in zip(sqz_cmpl, template_filter.shape)
1102
- )
1103
- ft_sqz = be.zeros(ret_shape, be._complex_dtype)
1104
-
1105
- norm_func, norm_template, mask_sum = normalize_template, False, 1
864
+ n_obs = None
1106
865
  if template_mask is not None:
1107
866
  template_mask = be.from_sharedarr(template_mask)
1108
- norm_template, mask_sum = True, be.sum(
1109
- be.astype(template_mask, be._overflow_safe_dtype),
1110
- axis=data_axes,
1111
- keepdims=True,
1112
- )
1113
- if be.datatype_bytes(template_mask.dtype) == 2:
1114
- norm_func = _normalize_template_overflow_safe
867
+ n_obs = be.sum(template_mask, axis=axes, keepdims=True)
1115
868
 
1116
- callback_func = conditional_execute(callback, callback is not None)
1117
- norm_template = conditional_execute(norm_func, norm_template)
1118
- norm_numerator = conditional_execute(
1119
- be.subtract, identity, _shape_match(numerator.shape, fast_shape)
1120
- )
1121
- norm_denominator = conditional_execute(
1122
- be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
1123
- )
1124
-
1125
- _fftargs = {
1126
- "real_dtype": be._float_dtype,
1127
- "cmpl_dtype": be._complex_dtype,
1128
- "fwd_axes": data_axes,
1129
- "inv_axes": data_axes,
1130
- "inv_shape": fast_ft_shape,
1131
- "inv_output_shape": fast_shape,
1132
- "temp_fwd": arr_sqz if _shape_match(ft_sqz.shape, sqz_cmpl) else arr,
1133
- }
1134
-
1135
- # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
1136
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
1137
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
1138
- _ = _fftargs.pop("temp_fwd", None)
869
+ norm_template = conditional_execute(normalize_template, n_obs is not None)
870
+ norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
871
+ norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
1139
872
 
1140
873
  template_filter_func = _create_filter_func(
1141
874
  arr_shape=template.shape,
1142
- arr_ft_shape=sqz_cmpl,
1143
- arr_filter=template_filter,
1144
- rfftn=rfftn,
1145
- irfftn=irfftn,
1146
- **_fftargs,
875
+ filter_shape=template_filter.shape,
876
+ arr_padded=True,
1147
877
  )
1148
878
 
1149
879
  for index in range(rotations.shape[0]):
1150
880
  be.fill(arr, 0)
1151
881
  rotation = rotations[index]
1152
- arr_sqz, _ = be.rigid_transform(
1153
- arr=template,
882
+ _, _ = be.rigid_transform(
883
+ arr=template[tmpl_subset],
1154
884
  rotation_matrix=rotation,
1155
- out=arr_sqz,
885
+ out=arr_sqz[tmpl_subset],
1156
886
  use_geometric_center=True,
1157
887
  order=interpolation_order,
1158
888
  cache=False,
889
+ batched=batched,
1159
890
  )
1160
891
  arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
1161
- norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)
892
+ norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)
1162
893
 
1163
- ft_sqz = rfftn(arr_norm, ft_sqz)
1164
- ft_temp = be.multiply(ft_target, ft_sqz, out=ft_temp)
1165
- arr = irfftn(ft_temp, arr)
894
+ ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
895
+ arr = _correlate_fts(ft_target, ft_sqz, ft_temp, arr, shape, axes)
1166
896
 
1167
- arr = norm_numerator(arr, numerator, out=arr)
1168
- arr = norm_denominator(arr, inv_denominator, out=arr)
1169
- callback_func(arr, rotation_matrix=rotation)
897
+ arr = norm_sub(arr, numerator, out=arr)
898
+ arr = norm_mul(arr, inv_denominator, out=arr)
899
+ callback(arr, rotation_matrix=rotation)
1170
900
 
1171
901
  return callback
1172
902
 
1173
903
 
904
+ def _get_batch_dim(target, template):
905
+ target_batch, template_batch = [], []
906
+ for i in range(len(target.shape)):
907
+ if target.shape[i] == 1 and template.shape[i] != 1:
908
+ template_batch.append(i)
909
+ if target.shape[i] != 1 and template.shape[i] == 1:
910
+ target_batch.append(i)
911
+
912
+ return target_batch, template_batch
913
+
914
+
915
def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=None):
    """
    Multiply two half-spectra and transform the product back to real space.

    Parameters
    ----------
    ft_tar : BackendArray
        Fourier transform of the target (any conjugation is expected to be
        applied by the caller; this function only multiplies).
    ft_tmpl : BackendArray
        Fourier transform of the template.
    ft_buffer : BackendArray
        Complex buffer receiving the elementwise product via ``out=``.
    real_buffer : BackendArray
        Real buffer passed as ``out=`` to the inverse transform.
    fast_shape : tuple of ints
        Real-space output shape, forwarded as ``s`` to ``be.irfftn``.
    axes : tuple of ints, optional
        Axes of the inverse transform, forwarded to ``be.irfftn``.

    Returns
    -------
    BackendArray
        Whatever ``be.irfftn`` returns for the product.
    """
    product = be.multiply(ft_tar, ft_tmpl, out=ft_buffer)
    return be.irfftn(product, out=real_buffer, s=fast_shape, axes=axes)
918
+
919
+
920
def _create_filter_func(
    arr_shape: Tuple[int],
    filter_shape: Tuple[int],
    arr_padded: bool = False,
    axes=None,
) -> Callable:
    """
    Configure template filtering function for Fourier transforms.

    Three cases are distinguished:

    1. ``filter_shape == (1,)``: no filtering is performed and the template
       is passed through unchanged.
    2. Base case: template and filter shapes are matched, and the template
       is filtered in place in the frequency domain.
    3. ``arr_padded=True``: the template buffer is larger than ``arr_shape``
       (e.g. zero-padded to an FFT-friendly shape), while the filter was
       computed for the unpadded template of shape ``arr_shape``. The
       unpadded region is cropped, filtered and written back in place.

    Broadcasting a filter over a template batch dimension relies on
    ``be.multiply`` broadcasting the filter against the spectrum.

    Parameters
    ----------
    arr_shape : tuple of ints
        Shape of the unpadded array to be filtered.
    filter_shape : tuple of ints
        Shape of the precomputed frequency-domain filter. ``(1,)`` disables
        filtering. In the padded case this is also the shape of the internal
        complex scratch buffer.
    arr_padded : bool, optional
        Whether the input template is padded and will need to be cropped
        to arr_shape prior to filter applications. Defaults to False.
    axes : tuple of ints, optional
        Axes to perform Fourier transform over. Defaults to None, i.e. all
        axes, matching the previous behavior.

    Returns
    -------
    Callable
        Filter function with parameters template, ft_temp and template_filter.
    """
    if filter_shape == (1,):
        return conditional_execute(identity, execute_operation=True)

    def _fft_shape(template):
        # ``s`` must list sizes only along the transformed axes.
        if axes is None:
            return template.shape
        return tuple(template.shape[i] for i in axes)

    # Default case, all shapes are correctly matched
    def _apply_filter(template, ft_temp, template_filter):
        shape = _fft_shape(template)
        ft_temp = be.rfftn(template, out=ft_temp, s=shape, axes=axes)
        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
        return be.irfftn(ft_temp, out=template, s=shape, axes=axes)

    if not arr_padded:
        return _apply_filter

    # Array is padded but filter is w.r.t to the original template.
    # Scratch buffers are allocated once here and reused on every call.
    real_subset = tuple(slice(0, x) for x in arr_shape)
    _template = be.zeros(arr_shape, be._float_dtype)
    _ft_temp = be.zeros(filter_shape, be._complex_dtype)

    def _apply_filter_subset(template, ft_temp, template_filter):
        # ft_temp is unused: the caller's buffer has the padded shape, so the
        # filter-shaped scratch buffer _ft_temp is used instead.
        _template[:] = template[real_subset]
        template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
        return template

    return _apply_filter_subset
976
+
977
+
1174
978
  MATCHING_EXHAUSTIVE_REGISTER = {
1175
979
  "CC": (cc_setup, corr_scoring),
1176
980
  "LCC": (lcc_setup, corr_scoring),
@@ -1179,6 +983,6 @@ MATCHING_EXHAUSTIVE_REGISTER = {
1179
983
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1180
984
  "FLC": (flc_setup, flc_scoring),
1181
985
  "MCC": (mcc_setup, mcc_scoring),
1182
- "batchFLCSpherical": (flcSphericalMask_setup, corr_scoring2),
986
+ "batchFLCSphericalMask": (flcSphericalMask_setup, corr_scoring2),
1183
987
  "batchFLC": (flc_setup, flc_scoring2),
1184
988
  }