pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. pytme-0.2.0.data/scripts/match_template.py +1019 -0
  2. pytme-0.2.0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0.dist-info/RECORD +72 -0
  6. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +459 -218
  9. pytme-0.1.9.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +533 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +35 -6
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +173 -78
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +76 -33
  23. tme/matching_exhaustive.py +354 -225
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +753 -649
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +176 -0
  30. tme/preprocessing/composable_filter.py +30 -0
  31. tme/preprocessing/compose.py +52 -0
  32. tme/preprocessing/frequency_filters.py +322 -0
  33. tme/preprocessing/tilt_series.py +967 -0
  34. tme/preprocessor.py +35 -25
  35. tme/structure.py +2 -37
  36. pytme-0.1.9.data/scripts/postprocess.py +0 -625
  37. pytme-0.1.9.dist-info/RECORD +0 -61
  38. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
  39. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
  40. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
  41. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
@@ -19,13 +19,12 @@ from scipy.ndimage import laplace
19
19
 
20
20
  from .analyzer import MaxScoreOverRotations
21
21
  from .matching_utils import (
22
- apply_convolution_mode,
23
22
  handle_traceback,
24
23
  split_numpy_array_slices,
25
24
  conditional_execute,
26
25
  )
27
26
  from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
28
- from .preprocessor import Preprocessor
27
+ from .preprocessing import Compose
29
28
  from .matching_data import MatchingData
30
29
  from .backends import backend
31
30
  from .types import NDArray, CallbackClass
@@ -43,6 +42,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
43
42
  return scan(**kwargs)
44
43
 
45
44
 
45
+ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
46
+ """
47
+ Standardizes the values in in template by subtracting the mean and dividing by the
48
+ standard deviation based on the elements in mask. Subsequently, the template is
49
+ multiplied by the mask.
50
+
51
+ Parameters
52
+ ----------
53
+ template : NDArray
54
+ The data array to be normalized. This array is modified in-place.
55
+ mask : NDArray
56
+ A boolean array of the same shape as `template`. True values indicate the positions in `template`
57
+ to consider for normalization.
58
+ mask_intensity : float
59
+ Mask intensity used to compute expectations.
60
+
61
+ References
62
+ ----------
63
+ .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
64
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
65
+ .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
66
+ R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
67
+ 13375 (2023)
68
+
69
+ Returns
70
+ -------
71
+ None
72
+ This function modifies `template` in-place and does not return any value.
73
+ """
74
+ masked_mean = backend.sum(backend.multiply(template, mask))
75
+ masked_mean = backend.divide(masked_mean, mask_intensity)
76
+ masked_std = backend.sum(backend.multiply(backend.square(template), mask))
77
+ masked_std = backend.subtract(
78
+ masked_std / mask_intensity, backend.square(masked_mean)
79
+ )
80
+ masked_std = backend.sqrt(backend.maximum(masked_std, 0))
81
+
82
+ backend.subtract(template, masked_mean, out=template)
83
+ backend.divide(template, masked_std, out=template)
84
+ backend.multiply(template, mask, out=template)
85
+ return None
86
+
87
+
88
+ def apply_filter(ft_template, template_filter):
89
+ # This is an approximation to applying the mask, irfftn, normalize, rfftn
90
+ std_before = backend.std(ft_template)
91
+ backend.multiply(ft_template, template_filter, out=ft_template)
92
+ backend.multiply(
93
+ ft_template, std_before / backend.std(ft_template), out=ft_template
94
+ )
95
+
96
+
46
97
  def cc_setup(
47
98
  rfftn: Callable,
48
99
  irfftn: Callable,
@@ -212,13 +263,12 @@ def corr_setup(
212
263
  template_mean = backend.sum(backend.multiply(template, template_mask))
213
264
  template_mean = backend.divide(template_mean, n_observations)
214
265
  template_ssd = backend.sum(
215
- backend.square(backend.multiply(
216
- backend.multiply(template, template_mean),
217
- template_mask
218
- ))
266
+ backend.square(
267
+ backend.multiply(backend.multiply(template, template_mean), template_mask)
268
+ )
219
269
  )
220
270
  template_volume = np.prod(template.shape)
221
- backend.multiply(template, template_mask, out = template)
271
+ backend.multiply(template, template_mask, out=template)
222
272
 
223
273
  # Final numerator is score - numerator2
224
274
  numerator2 = backend.multiply(target_window_sum, template_mean)
@@ -314,49 +364,6 @@ def cam_setup(**kwargs):
314
364
  return corr_setup(**kwargs)
315
365
 
316
366
 
317
- def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
318
- """
319
- Standardizes the values in in template by subtracting the mean and dividing by the
320
- standard deviation based on the elements in mask. Subsequently, the template is
321
- multiplied by the mask.
322
-
323
- Parameters
324
- ----------
325
- template : NDArray
326
- The data array to be normalized. This array is modified in-place.
327
- mask : NDArray
328
- A boolean array of the same shape as `template`. True values indicate the positions in `template`
329
- to consider for normalization.
330
- mask_intensity : float
331
- Mask intensity used to compute expectations.
332
-
333
- References
334
- ----------
335
- .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
336
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
337
- .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
338
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
339
- 13375 (2023)
340
-
341
- Returns
342
- -------
343
- None
344
- This function modifies `template` in-place and does not return any value.
345
- """
346
- masked_mean = backend.sum(backend.multiply(template, mask))
347
- masked_mean = backend.divide(masked_mean, mask_intensity)
348
- masked_std = backend.sum(backend.multiply(backend.square(template), mask))
349
- masked_std = backend.subtract(
350
- masked_std / mask_intensity, backend.square(masked_mean)
351
- )
352
- masked_std = backend.sqrt(backend.maximum(masked_std, 0))
353
-
354
- backend.subtract(template, masked_mean, out=template)
355
- backend.divide(template, masked_std, out=template)
356
- backend.multiply(template, mask, out=template)
357
- return None
358
-
359
-
360
367
  def flc_setup(
361
368
  rfftn: Callable,
362
369
  irfftn: Callable,
@@ -419,7 +426,7 @@ def flc_setup(
419
426
  arr=ft_target2, shared_memory_handler=shared_memory_handler
420
427
  )
421
428
 
422
- _normalize_under_mask(
429
+ normalize_under_mask(
423
430
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
424
431
  )
425
432
 
@@ -541,7 +548,7 @@ def flcSphericalMask_setup(
541
548
  backend.fill(temp2, 0)
542
549
  temp2[nonzero_indices] = 1 / temp[nonzero_indices]
543
550
 
544
- _normalize_under_mask(
551
+ normalize_under_mask(
545
552
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
546
553
  )
547
554
 
@@ -728,10 +735,6 @@ def corr_scoring(
728
735
  datatype.
729
736
  numerator2 : Tuple[type, Tuple[int], type]
730
737
  Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
731
- targetshape : Tuple[int]
732
- The shape of the target.
733
- templateshape : Tuple[int]
734
- The shape of the template.
735
738
  fast_shape : Tuple[int]
736
739
  The shape for fast Fourier transform.
737
740
  fast_ft_shape : Tuple[int]
@@ -750,8 +753,6 @@ def corr_scoring(
750
753
  instantiable.
751
754
  interpolation_order : int
752
755
  The order of interpolation to be used while rotating the template.
753
- convolution_mode : str, optional
754
- Mode to use for convolution, default is "full".
755
756
  **kwargs :
756
757
  Additional arguments to be passed to the function.
757
758
 
@@ -767,37 +768,23 @@ def corr_scoring(
767
768
  :py:meth:`cam_setup`
768
769
  :py:meth:`flcSphericalMask_setup`
769
770
  """
770
- template_buffer, template_shape, template_dtype = template
771
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
772
- inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
773
- numerator2_buffer, numerator2_shape, _ = numerator2
774
- filter_buffer, filter_shape, filter_dtype = template_filter
775
-
771
+ callback = callback_class
776
772
  if callback_class is not None and isinstance(callback_class, type):
777
773
  callback = callback_class(**callback_class_args)
778
- elif not isinstance(callback_class, type):
779
- callback = callback_class
780
774
 
781
- # Retrieve objects from shared memory
782
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
783
- ft_target = backend.sharedarr_to_arr(
784
- ft_target_shape, ft_target_dtype, ft_target_buffer
785
- )
786
- inv_denominator = backend.sharedarr_to_arr(
787
- inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
788
- )
789
- numerator2 = backend.sharedarr_to_arr(
790
- numerator2_shape, template_dtype, numerator2_buffer
791
- )
792
- template_filter = backend.sharedarr_to_arr(
793
- filter_shape, filter_dtype, filter_buffer
794
- )
775
+ template_buffer, template_shape, template_dtype = template
776
+ template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
777
+ ft_target = backend.sharedarr_to_arr(*ft_target)
778
+ inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
779
+ numerator2 = backend.sharedarr_to_arr(*numerator2)
780
+ template_filter = backend.sharedarr_to_arr(*template_filter)
795
781
 
796
782
  norm_template, template_mask, mask_sum = False, 1, 1
797
783
  if "template_mask" in kwargs:
798
- template_mask = backend.sharedarr_to_arr(template_shape, template_dtype, kwargs["template_mask"][0])
784
+ template_mask = backend.sharedarr_to_arr(
785
+ kwargs["template_mask"][0], template_shape, template_dtype
786
+ )
799
787
  norm_template, mask_sum = True, backend.sum(template_mask)
800
- norm_template = conditional_execute(_normalize_under_mask, norm_template)
801
788
 
802
789
  arr = backend.preallocate_array(fast_shape, real_dtype)
803
790
  ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
@@ -816,17 +803,15 @@ def corr_scoring(
816
803
  norm_denominator = (backend.sum(inv_denominator) != 1) & (
817
804
  backend.size(inv_denominator) != 1
818
805
  )
819
- filter_template = backend.size(template_filter) != 0
820
806
 
807
+ norm_template = conditional_execute(normalize_under_mask, norm_template)
808
+ callback_func = conditional_execute(callback, callback_class is not None)
821
809
  norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
822
810
  norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
823
- template_filter_func = conditional_execute(backend.multiply, filter_template)
824
-
825
- axis = tuple(range(arr.ndim))
826
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
827
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
811
+ template_filter_func = conditional_execute(
812
+ apply_filter, backend.size(template_filter) != 1
813
+ )
828
814
 
829
- template_sum = backend.sum(template)
830
815
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
831
816
  for index in range(rotations.shape[0]):
832
817
  rotation = rotations[index]
@@ -838,34 +823,24 @@ def corr_scoring(
838
823
  use_geometric_center=False,
839
824
  order=interpolation_order,
840
825
  )
841
- rotation_norm = template_sum / backend.sum(arr)
842
- backend.multiply(arr, rotation_norm, out=arr)
826
+
843
827
  norm_template(arr[unpadded_slice], template_mask, mask_sum)
844
828
 
845
829
  rfftn(arr, ft_temp)
846
- template_filter_func(ft_temp, template_filter, out=ft_temp)
847
-
830
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
848
831
  backend.multiply(ft_target, ft_temp, out=ft_temp)
849
832
  irfftn(ft_temp, arr)
850
833
 
851
834
  norm_func_numerator(arr, numerator2, out=arr)
852
835
  norm_func_denominator(arr, inv_denominator, out=arr)
853
836
 
854
- if fourier_shift_scores:
855
- arr = backend.roll(arr, shift=fourier_shift, axis=axis)
856
-
857
- score = apply_convolution_mode(
858
- arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
837
+ callback_func(
838
+ arr,
839
+ rotation_matrix=rotation,
840
+ rotation_index=index,
841
+ **callback_class_args,
859
842
  )
860
843
 
861
- if callback_class is not None:
862
- callback(
863
- score,
864
- rotation_matrix=rotation,
865
- rotation_index=index,
866
- **callback_class_args,
867
- )
868
-
869
844
  return callback
870
845
 
871
846
 
@@ -913,32 +888,15 @@ def flc_scoring(
913
888
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
914
889
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
915
890
  """
916
- template_buffer, template_shape, template_dtype = template
917
- template_mask_buffer, *_ = template_mask
918
- filter_buffer, filter_shape, filter_dtype = template_filter
919
-
920
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
921
- ft_target2_buffer, *_ = ft_target2
922
-
891
+ callback = callback_class
923
892
  if callback_class is not None and isinstance(callback_class, type):
924
893
  callback = callback_class(**callback_class_args)
925
- elif not isinstance(callback_class, type):
926
- callback = callback_class
927
894
 
928
- # Retrieve objects from shared memory
929
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
930
- template_mask = backend.sharedarr_to_arr(
931
- template_shape, template_dtype, template_mask_buffer
932
- )
933
- ft_target = backend.sharedarr_to_arr(
934
- ft_target_shape, ft_target_dtype, ft_target_buffer
935
- )
936
- ft_target2 = backend.sharedarr_to_arr(
937
- ft_target_shape, ft_target_dtype, ft_target2_buffer
938
- )
939
- template_filter = backend.sharedarr_to_arr(
940
- filter_shape, filter_dtype, filter_buffer
941
- )
895
+ template = backend.sharedarr_to_arr(*template)
896
+ template_mask = backend.sharedarr_to_arr(*template_mask)
897
+ ft_target = backend.sharedarr_to_arr(*ft_target)
898
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
899
+ template_filter = backend.sharedarr_to_arr(*template_filter)
942
900
 
943
901
  arr = backend.preallocate_array(fast_shape, real_dtype)
944
902
  temp = backend.preallocate_array(fast_shape, real_dtype)
@@ -957,12 +915,10 @@ def flc_scoring(
957
915
  temp_fft=ft_temp,
958
916
  )
959
917
  eps = backend.eps(real_dtype)
960
- filter_template = backend.size(template_filter) != 0
961
- template_filter_func = conditional_execute(backend.multiply, filter_template)
962
-
963
- axis = tuple(range(arr.ndim))
964
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
965
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
918
+ template_filter_func = conditional_execute(
919
+ apply_filter, backend.size(template_filter) != 1
920
+ )
921
+ callback_func = conditional_execute(callback, callback_class is not None)
966
922
 
967
923
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
968
924
  for index in range(rotations.shape[0]):
@@ -981,7 +937,7 @@ def flc_scoring(
981
937
  # Given the amount of FFTs, might aswell normalize properly
982
938
  n_observations = backend.sum(temp)
983
939
 
984
- _normalize_under_mask(
940
+ normalize_under_mask(
985
941
  template=arr[unpadded_slice],
986
942
  mask=temp[unpadded_slice],
987
943
  mask_intensity=n_observations,
@@ -1004,7 +960,7 @@ def flc_scoring(
1004
960
  backend.multiply(temp, n_observations, out=temp)
1005
961
 
1006
962
  rfftn(arr, ft_temp)
1007
- template_filter_func(ft_temp, template_filter, out=ft_temp)
963
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1008
964
  backend.multiply(ft_target, ft_temp, out=ft_temp)
1009
965
  irfftn(ft_temp, arr)
1010
966
 
@@ -1013,23 +969,161 @@ def flc_scoring(
1013
969
  backend.fill(temp2, 0)
1014
970
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1015
971
 
1016
- convolution_mode = kwargs.get("convolution_mode", "full")
972
+ callback_func(
973
+ temp2,
974
+ rotation_matrix=rotation,
975
+ rotation_index=index,
976
+ **callback_class_args,
977
+ )
978
+
979
+ return callback
980
+
981
+
982
+ def flc_scoring2(
983
+ template: Tuple[type, Tuple[int], type],
984
+ template_mask: Tuple[type, Tuple[int], type],
985
+ ft_target: Tuple[type, Tuple[int], type],
986
+ ft_target2: Tuple[type, Tuple[int], type],
987
+ template_filter: Tuple[type, Tuple[int], type],
988
+ targetshape: Tuple[int],
989
+ templateshape: Tuple[int],
990
+ fast_shape: Tuple[int],
991
+ fast_ft_shape: Tuple[int],
992
+ rotations: NDArray,
993
+ real_dtype: type,
994
+ complex_dtype: type,
995
+ callback_class: CallbackClass,
996
+ callback_class_args: Dict,
997
+ interpolation_order: int,
998
+ **kwargs,
999
+ ) -> CallbackClass:
1000
+ """
1001
+ Computes a normalized cross-correlation score of a target f a template g
1002
+ and a mask m:
1003
+
1004
+ .. math::
1005
+
1006
+ \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
1007
+ {N_m * \\sqrt{
1008
+ \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
1009
+ }
1010
+
1011
+ Where:
1012
+
1013
+ .. math::
1014
+
1015
+ CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1016
+
1017
+ and Nm is the number of voxels within the template mask m.
1018
+
1019
+ References
1020
+ ----------
1021
+ .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
1022
+ Microsc. Microanal. 26, 2516 (2020)
1023
+ .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1024
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
1025
+ """
1026
+ callback = callback_class
1027
+ if callback_class is not None and isinstance(callback_class, type):
1028
+ callback = callback_class(**callback_class_args)
1029
+
1030
+ # Retrieve objects from shared memory
1031
+ template = backend.sharedarr_to_arr(*template)
1032
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1033
+ ft_target = backend.sharedarr_to_arr(*ft_target)
1034
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
1035
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1017
1036
 
1018
- if fourier_shift_scores:
1019
- temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
1037
+ arr = backend.preallocate_array(fast_shape, real_dtype)
1038
+ temp = backend.preallocate_array(fast_shape, real_dtype)
1039
+ temp2 = backend.preallocate_array(fast_shape, real_dtype)
1040
+
1041
+ ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
1042
+ ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
1020
1043
 
1021
- score = apply_convolution_mode(
1022
- temp2, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
1044
+ eps = backend.eps(real_dtype)
1045
+ template_filter_func = conditional_execute(
1046
+ apply_filter, backend.size(template_filter) != 1
1047
+ )
1048
+ callback_func = conditional_execute(callback, callback_class is not None)
1049
+
1050
+ squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
1051
+ squeeze = tuple(
1052
+ slice(0, stop) if i not in squeeze_axis else 0
1053
+ for i, stop in enumerate(template.shape)
1054
+ )
1055
+ squeeze_fast = tuple(
1056
+ slice(0, stop) if i not in squeeze_axis else 0
1057
+ for i, stop in enumerate(fast_shape)
1058
+ )
1059
+ squeeze_fast_ft = tuple(
1060
+ slice(0, stop) if i not in squeeze_axis else 0
1061
+ for i, stop in enumerate(fast_ft_shape)
1062
+ )
1063
+
1064
+ rfftn, irfftn = backend.build_fft(
1065
+ fast_shape=temp[squeeze_fast].shape,
1066
+ fast_ft_shape=fast_ft_shape,
1067
+ real_dtype=real_dtype,
1068
+ complex_dtype=complex_dtype,
1069
+ fftargs=kwargs.get("fftargs", {}),
1070
+ inverse_fast_shape=fast_shape,
1071
+ temp_real=arr[squeeze_fast],
1072
+ temp_fft=ft_temp,
1073
+ )
1074
+ for index in range(rotations.shape[0]):
1075
+ rotation = rotations[index]
1076
+ backend.fill(arr, 0)
1077
+ backend.fill(temp, 0)
1078
+ backend.rotate_array(
1079
+ arr=template[squeeze],
1080
+ arr_mask=template_mask[squeeze],
1081
+ rotation_matrix=rotation,
1082
+ out=arr[squeeze],
1083
+ out_mask=temp[squeeze],
1084
+ use_geometric_center=False,
1085
+ order=interpolation_order,
1023
1086
  )
1087
+ # Given the amount of FFTs, might aswell normalize properly
1088
+ n_observations = backend.sum(temp)
1024
1089
 
1025
- if callback_class is not None:
1026
- callback(
1027
- score,
1028
- rotation_matrix=rotation,
1029
- rotation_index=index,
1030
- **callback_class_args,
1031
- )
1090
+ normalize_under_mask(
1091
+ template=arr[squeeze],
1092
+ mask=temp[squeeze],
1093
+ mask_intensity=n_observations,
1094
+ )
1095
+ rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
1096
+
1097
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1098
+ irfftn(ft_denom, temp)
1099
+ backend.divide(temp, n_observations, out=temp)
1100
+ backend.square(temp, out=temp)
1101
+
1102
+ backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
1103
+ irfftn(ft_denom, temp2)
1104
+ backend.divide(temp2, n_observations, out=temp2)
1032
1105
 
1106
+ backend.subtract(temp2, temp, out=temp)
1107
+ backend.maximum(temp, 0.0, out=temp)
1108
+ backend.sqrt(temp, out=temp)
1109
+ backend.multiply(temp, n_observations, out=temp)
1110
+
1111
+ rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
1112
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1113
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1114
+ irfftn(ft_denom, arr)
1115
+
1116
+ tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
1117
+ nonzero_indices = temp > tol
1118
+ backend.fill(temp2, 0)
1119
+ temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1120
+
1121
+ callback_func(
1122
+ temp2,
1123
+ rotation_matrix=rotation,
1124
+ rotation_index=index,
1125
+ **callback_class_args,
1126
+ )
1033
1127
  return callback
1034
1128
 
1035
1129
 
@@ -1083,35 +1177,18 @@ def mcc_scoring(
1083
1177
  --------
1084
1178
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
1085
1179
  """
1086
- template_buffer, template_shape, template_dtype = template
1087
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
1088
- ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
1089
- template_mask_buffer, _, _ = template
1090
- ft_target_mask_buffer, _, _ = ft_target
1091
- filter_buffer, filter_shape, filter_dtype = template_filter
1092
-
1180
+ callback = callback_class
1093
1181
  if callback_class is not None and isinstance(callback_class, type):
1094
1182
  callback = callback_class(**callback_class_args)
1095
- elif not isinstance(callback_class, type):
1096
- callback = callback_class
1097
1183
 
1098
1184
  # Retrieve objects from shared memory
1099
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
1100
- target_ft = backend.sharedarr_to_arr(
1101
- ft_target_shape, ft_target_dtype, ft_target_buffer
1102
- )
1103
- target_ft2 = backend.sharedarr_to_arr(
1104
- ft_target_shape, ft_target_dtype, ft_target2_buffer
1105
- )
1106
- template_mask = backend.sharedarr_to_arr(
1107
- template_shape, template_dtype, template_mask_buffer
1108
- )
1109
- target_mask_ft = backend.sharedarr_to_arr(
1110
- ft_target_shape, ft_target_dtype, ft_target_mask_buffer
1111
- )
1112
- template_filter = backend.sharedarr_to_arr(
1113
- filter_shape, filter_dtype, filter_buffer
1114
- )
1185
+ template_buffer, template_shape, template_dtype = template
1186
+ template = backend.sharedarr_to_arr(*template)
1187
+ target_ft = backend.sharedarr_to_arr(*ft_target)
1188
+ target_ft2 = backend.sharedarr_to_arr(*ft_target2)
1189
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1190
+ target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
1191
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1115
1192
 
1116
1193
  axes = tuple(range(template.ndim))
1117
1194
  eps = backend.eps(real_dtype)
@@ -1136,14 +1213,10 @@ def mcc_scoring(
1136
1213
  temp_fft=temp_ft,
1137
1214
  )
1138
1215
 
1139
- filter_template = backend.size(template_filter) != 0
1140
- template_filter_func = conditional_execute(backend.multiply, filter_template)
1141
-
1142
- axis = tuple(range(template.ndim))
1143
- fourier_shift = callback_class_args.get(
1144
- "fourier_shift", backend.zeros(template.ndim)
1216
+ template_filter_func = conditional_execute(
1217
+ apply_filter, backend.size(template_filter) != 1
1145
1218
  )
1146
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
1219
+ callback_func = conditional_execute(callback, callback_class is not None)
1147
1220
 
1148
1221
  # Calculate scores across all rotations
1149
1222
  for index in range(rotations.shape[0]):
@@ -1165,7 +1238,7 @@ def mcc_scoring(
1165
1238
 
1166
1239
  # template_rot_ft
1167
1240
  rfftn(template_rot, temp_ft)
1168
- template_filter_func(temp_ft, template_filter, out=temp_ft)
1241
+ template_filter_func(ft_template=temp_ft, template_filter=template_filter)
1169
1242
  irfftn(target_mask_ft * temp_ft, temp2)
1170
1243
  irfftn(target_ft * temp_ft, numerator)
1171
1244
 
@@ -1221,25 +1294,68 @@ def mcc_scoring(
1221
1294
  mask_overlap, axis=axes, keepdims=True
1222
1295
  )
1223
1296
  temp[mask_overlap < number_px_threshold] = 0.0
1224
- convolution_mode = kwargs.get("convolution_mode", "full")
1225
1297
 
1226
- if fourier_shift_scores:
1227
- temp = backend.roll(temp, shift=fourier_shift, axis=axis)
1228
-
1229
- score = apply_convolution_mode(
1230
- temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
1298
+ callback_func(
1299
+ temp,
1300
+ rotation_matrix=rotation,
1301
+ rotation_index=index,
1302
+ **callback_class_args,
1231
1303
  )
1232
- if callback_class is not None:
1233
- callback(
1234
- score,
1235
- rotation_matrix=rotation,
1236
- rotation_index=index,
1237
- **callback_class_args,
1238
- )
1239
1304
 
1240
1305
  return callback
1241
1306
 
1242
1307
 
1308
+ def _setup_template_filter_apply_target_filter(
1309
+ matching_data: MatchingData,
1310
+ rfftn: Callable,
1311
+ irfftn: Callable,
1312
+ fast_shape: Tuple[int],
1313
+ fast_ft_shape: Tuple[int],
1314
+ ):
1315
+ filter_template = isinstance(matching_data.template_filter, Compose)
1316
+ filter_target = isinstance(matching_data.target_filter, Compose)
1317
+
1318
+ template_filter = backend.full(
1319
+ shape=(1,), fill_value=1, dtype=backend._default_dtype
1320
+ )
1321
+
1322
+ if not filter_template and not filter_target:
1323
+ return template_filter
1324
+
1325
+ target_temp = backend.astype(
1326
+ backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
1327
+ )
1328
+ target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1329
+ rfftn(target_temp, target_temp_ft)
1330
+ if isinstance(matching_data.template_filter, Compose):
1331
+ template_filter = matching_data.template_filter(
1332
+ shape=fast_shape,
1333
+ return_real_fourier=True,
1334
+ shape_is_real_fourier=False,
1335
+ data_rfft=target_temp_ft,
1336
+ )
1337
+ template_filter = template_filter["data"]
1338
+ template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1339
+
1340
+ if isinstance(matching_data.target_filter, Compose):
1341
+ target_filter = matching_data.target_filter(
1342
+ shape=fast_shape,
1343
+ return_real_fourier=True,
1344
+ shape_is_real_fourier=False,
1345
+ data_rfft=target_temp_ft,
1346
+ weight_type=None,
1347
+ )
1348
+ target_filter = target_filter["data"]
1349
+ backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1350
+
1351
+ irfftn(target_temp_ft, target_temp)
1352
+ matching_data._target = backend.topleft_pad(
1353
+ target_temp, matching_data.target.shape
1354
+ )
1355
+
1356
+ return template_filter
1357
+
1358
+
1243
1359
  def device_memory_handler(func: Callable):
1244
1360
  """Decorator function providing SharedMemory Handler."""
1245
1361
 
@@ -1310,18 +1426,17 @@ def scan(
1310
1426
  Tuple
1311
1427
  The merged results from callback_class if provided otherwise None.
1312
1428
  """
1429
+ matching_data.to_backend()
1313
1430
  shape_diff = backend.subtract(
1314
- matching_data._target.shape, matching_data._template.shape
1431
+ matching_data._output_target_shape, matching_data._output_template_shape
1315
1432
  )
1433
+ shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
1316
1434
  if backend.sum(shape_diff < 0) and not pad_fourier:
1317
1435
  warnings.warn(
1318
1436
  "Target is larger than template and Fourier padding is turned off."
1319
- " This can lead to shifted results. You can swap template and target, or"
1320
- " zero-pad the target."
1437
+ " This can lead to shifted results. You can swap template and target, "
1438
+ " zero-pad the target or turn off template centering."
1321
1439
  )
1322
-
1323
- matching_data.to_backend()
1324
-
1325
1440
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1326
1441
  pad_fourier=pad_fourier
1327
1442
  )
@@ -1334,6 +1449,15 @@ def scan(
1334
1449
  complex_dtype=matching_data._complex_dtype,
1335
1450
  fftargs=fftargs,
1336
1451
  )
1452
+
1453
+ template_filter = _setup_template_filter_apply_target_filter(
1454
+ matching_data=matching_data,
1455
+ rfftn=rfftn,
1456
+ irfftn=irfftn,
1457
+ fast_shape=fast_shape,
1458
+ fast_ft_shape=fast_ft_shape,
1459
+ )
1460
+
1337
1461
  setup = matching_setup(
1338
1462
  rfftn=rfftn,
1339
1463
  irfftn=irfftn,
@@ -1351,22 +1475,7 @@ def scan(
1351
1475
  )
1352
1476
  rfftn, irfftn = None, None
1353
1477
 
1354
- template_filter, preprocessor = None, Preprocessor()
1355
- for method, parameters in matching_data.template_filter.items():
1356
- parameters["shape"] = fast_shape
1357
- parameters["omit_negative_frequencies"] = True
1358
- out = preprocessor.apply_method(method=method, parameters=parameters)
1359
- if template_filter is None:
1360
- template_filter = out
1361
- np.multiply(template_filter, out, out=template_filter)
1362
-
1363
- if template_filter is None:
1364
- template_filter = backend.full(
1365
- shape=(1,), fill_value=1, dtype=backend._default_dtype
1366
- )
1367
- else:
1368
- template_filter = backend.to_backend_array(template_filter)
1369
-
1478
+ template_filter = backend.to_backend_array(template_filter)
1370
1479
  template_filter = backend.astype(template_filter, backend._default_dtype)
1371
1480
  template_filter_buffer = backend.arr_to_sharedarr(
1372
1481
  arr=template_filter,
@@ -1388,14 +1497,20 @@ def scan(
1388
1497
  callback_class = setup.pop("callback_class", callback_class)
1389
1498
  callback_class_args = setup.pop("callback_class_args", callback_class_args)
1390
1499
  callback_classes = [callback_class for _ in range(n_callback_classes)]
1500
+
1501
+ convolution_mode = "same"
1502
+ if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1503
+ convolution_mode = "valid"
1504
+
1505
+ callback_class_args["fourier_shift"] = fourier_shift
1506
+ callback_class_args["convolution_mode"] = convolution_mode
1507
+ callback_class_args["targetshape"] = setup["targetshape"]
1508
+ callback_class_args["templateshape"] = setup["templateshape"]
1509
+
1391
1510
  if callback_class == MaxScoreOverRotations:
1392
- score_space_shape = backend.subtract(
1393
- matching_data.target.shape,
1394
- matching_data._target_pad,
1395
- )
1396
1511
  callback_classes = [
1397
1512
  class_name(
1398
- score_space_shape=score_space_shape,
1513
+ score_space_shape=fast_shape,
1399
1514
  score_space_dtype=matching_data._default_dtype,
1400
1515
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
1401
1516
  rotation_space_dtype=backend._default_dtype_int,
@@ -1435,9 +1550,20 @@ def scan(
1435
1550
  for index, rotation in enumerate(rotation_list)
1436
1551
  )
1437
1552
 
1553
+ callbacks = callbacks[0:n_callback_classes]
1438
1554
  callbacks = [
1439
- tuple(callback)
1440
- for callback in callbacks[0:n_callback_classes]
1555
+ tuple(
1556
+ callback._postprocess(
1557
+ fourier_shift=fourier_shift,
1558
+ convolution_mode=convolution_mode,
1559
+ targetshape=setup["targetshape"],
1560
+ templateshape=setup["templateshape"],
1561
+ shared_memory_handler=kwargs.get("shared_memory_handler", None),
1562
+ )
1563
+ )
1564
+ if hasattr(callback, "_postprocess")
1565
+ else tuple(callback)
1566
+ for callback in callbacks
1441
1567
  if callback is not None
1442
1568
  ]
1443
1569
  backend.free_cache()
@@ -1549,11 +1675,13 @@ def scan_subsets(
1549
1675
  matching_data._target, matching_data._template = None, None
1550
1676
  matching_data._target_mask, matching_data._template_mask = None, None
1551
1677
 
1678
+ candidates = None
1552
1679
  if callback_class is not None:
1553
1680
  candidates = callback_class.merge(
1554
1681
  results, **callback_class_args, inner_merge=False
1555
1682
  )
1556
- return candidates
1683
+
1684
+ return candidates
1557
1685
 
1558
1686
 
1559
1687
  MATCHING_EXHAUSTIVE_REGISTER = {
@@ -1563,6 +1691,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
1563
1691
  "CAM": (cam_setup, corr_scoring),
1564
1692
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1565
1693
  "FLC": (flc_setup, flc_scoring),
1694
+ "FLC2": (flc_setup, flc_scoring2),
1566
1695
  "MCC": (mcc_setup, mcc_scoring),
1567
1696
  }
1568
1697