pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/match_template.py +148 -126
  2. pytme-0.2.0b0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0b0.dist-info/RECORD +66 -0
  6. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +148 -126
  9. scripts/match_template_filters.py +852 -0
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +545 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +33 -2
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +156 -63
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +74 -33
  23. tme/matching_exhaustive.py +351 -208
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +728 -651
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessor.py +21 -18
  29. tme/structure.py +2 -37
  30. pytme-0.1.9.data/scripts/postprocess.py +0 -625
  31. pytme-0.1.9.dist-info/RECORD +0 -61
  32. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/estimate_ram_usage.py +0 -0
  33. {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocess.py +0 -0
  34. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/LICENSE +0 -0
  35. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/entry_points.txt +0 -0
  36. {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from .matching_utils import (
25
25
  conditional_execute,
26
26
  )
27
27
  from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
28
+ # from .preprocessing import Compose
28
29
  from .preprocessor import Preprocessor
29
30
  from .matching_data import MatchingData
30
31
  from .backends import backend
@@ -43,6 +44,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
43
44
  return scan(**kwargs)
44
45
 
45
46
 
47
+ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
48
+ """
49
+ Standardizes the values in in template by subtracting the mean and dividing by the
50
+ standard deviation based on the elements in mask. Subsequently, the template is
51
+ multiplied by the mask.
52
+
53
+ Parameters
54
+ ----------
55
+ template : NDArray
56
+ The data array to be normalized. This array is modified in-place.
57
+ mask : NDArray
58
+ A boolean array of the same shape as `template`. True values indicate the positions in `template`
59
+ to consider for normalization.
60
+ mask_intensity : float
61
+ Mask intensity used to compute expectations.
62
+
63
+ References
64
+ ----------
65
+ .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
66
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
67
+ .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
68
+ R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
69
+ 13375 (2023)
70
+
71
+ Returns
72
+ -------
73
+ None
74
+ This function modifies `template` in-place and does not return any value.
75
+ """
76
+ masked_mean = backend.sum(backend.multiply(template, mask))
77
+ masked_mean = backend.divide(masked_mean, mask_intensity)
78
+ masked_std = backend.sum(backend.multiply(backend.square(template), mask))
79
+ masked_std = backend.subtract(
80
+ masked_std / mask_intensity, backend.square(masked_mean)
81
+ )
82
+ masked_std = backend.sqrt(backend.maximum(masked_std, 0))
83
+
84
+ backend.subtract(template, masked_mean, out=template)
85
+ backend.divide(template, masked_std, out=template)
86
+ backend.multiply(template, mask, out=template)
87
+ return None
88
+
89
+
90
+ def apply_filter(ft_template, template_filter):
91
+ # This is an approximation to applying the mask, irfftn, normalize, rfftn
92
+ std_before = backend.std(ft_template)
93
+ backend.multiply(ft_template, template_filter, out=ft_template)
94
+ backend.multiply(
95
+ ft_template, std_before / backend.std(ft_template), out=ft_template
96
+ )
97
+
98
+
46
99
  def cc_setup(
47
100
  rfftn: Callable,
48
101
  irfftn: Callable,
@@ -212,13 +265,12 @@ def corr_setup(
212
265
  template_mean = backend.sum(backend.multiply(template, template_mask))
213
266
  template_mean = backend.divide(template_mean, n_observations)
214
267
  template_ssd = backend.sum(
215
- backend.square(backend.multiply(
216
- backend.multiply(template, template_mean),
217
- template_mask
218
- ))
268
+ backend.square(
269
+ backend.multiply(backend.multiply(template, template_mean), template_mask)
270
+ )
219
271
  )
220
272
  template_volume = np.prod(template.shape)
221
- backend.multiply(template, template_mask, out = template)
273
+ backend.multiply(template, template_mask, out=template)
222
274
 
223
275
  # Final numerator is score - numerator2
224
276
  numerator2 = backend.multiply(target_window_sum, template_mean)
@@ -314,49 +366,6 @@ def cam_setup(**kwargs):
314
366
  return corr_setup(**kwargs)
315
367
 
316
368
 
317
- def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
318
- """
319
- Standardizes the values in in template by subtracting the mean and dividing by the
320
- standard deviation based on the elements in mask. Subsequently, the template is
321
- multiplied by the mask.
322
-
323
- Parameters
324
- ----------
325
- template : NDArray
326
- The data array to be normalized. This array is modified in-place.
327
- mask : NDArray
328
- A boolean array of the same shape as `template`. True values indicate the positions in `template`
329
- to consider for normalization.
330
- mask_intensity : float
331
- Mask intensity used to compute expectations.
332
-
333
- References
334
- ----------
335
- .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
336
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
337
- .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
338
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
339
- 13375 (2023)
340
-
341
- Returns
342
- -------
343
- None
344
- This function modifies `template` in-place and does not return any value.
345
- """
346
- masked_mean = backend.sum(backend.multiply(template, mask))
347
- masked_mean = backend.divide(masked_mean, mask_intensity)
348
- masked_std = backend.sum(backend.multiply(backend.square(template), mask))
349
- masked_std = backend.subtract(
350
- masked_std / mask_intensity, backend.square(masked_mean)
351
- )
352
- masked_std = backend.sqrt(backend.maximum(masked_std, 0))
353
-
354
- backend.subtract(template, masked_mean, out=template)
355
- backend.divide(template, masked_std, out=template)
356
- backend.multiply(template, mask, out=template)
357
- return None
358
-
359
-
360
369
  def flc_setup(
361
370
  rfftn: Callable,
362
371
  irfftn: Callable,
@@ -419,7 +428,7 @@ def flc_setup(
419
428
  arr=ft_target2, shared_memory_handler=shared_memory_handler
420
429
  )
421
430
 
422
- _normalize_under_mask(
431
+ normalize_under_mask(
423
432
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
424
433
  )
425
434
 
@@ -541,7 +550,7 @@ def flcSphericalMask_setup(
541
550
  backend.fill(temp2, 0)
542
551
  temp2[nonzero_indices] = 1 / temp[nonzero_indices]
543
552
 
544
- _normalize_under_mask(
553
+ normalize_under_mask(
545
554
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
546
555
  )
547
556
 
@@ -728,10 +737,6 @@ def corr_scoring(
728
737
  datatype.
729
738
  numerator2 : Tuple[type, Tuple[int], type]
730
739
  Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
731
- targetshape : Tuple[int]
732
- The shape of the target.
733
- templateshape : Tuple[int]
734
- The shape of the template.
735
740
  fast_shape : Tuple[int]
736
741
  The shape for fast Fourier transform.
737
742
  fast_ft_shape : Tuple[int]
@@ -750,8 +755,6 @@ def corr_scoring(
750
755
  instantiable.
751
756
  interpolation_order : int
752
757
  The order of interpolation to be used while rotating the template.
753
- convolution_mode : str, optional
754
- Mode to use for convolution, default is "full".
755
758
  **kwargs :
756
759
  Additional arguments to be passed to the function.
757
760
 
@@ -767,37 +770,23 @@ def corr_scoring(
767
770
  :py:meth:`cam_setup`
768
771
  :py:meth:`flcSphericalMask_setup`
769
772
  """
770
- template_buffer, template_shape, template_dtype = template
771
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
772
- inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
773
- numerator2_buffer, numerator2_shape, _ = numerator2
774
- filter_buffer, filter_shape, filter_dtype = template_filter
775
-
773
+ callback = callback_class
776
774
  if callback_class is not None and isinstance(callback_class, type):
777
775
  callback = callback_class(**callback_class_args)
778
- elif not isinstance(callback_class, type):
779
- callback = callback_class
780
776
 
781
- # Retrieve objects from shared memory
782
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
783
- ft_target = backend.sharedarr_to_arr(
784
- ft_target_shape, ft_target_dtype, ft_target_buffer
785
- )
786
- inv_denominator = backend.sharedarr_to_arr(
787
- inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
788
- )
789
- numerator2 = backend.sharedarr_to_arr(
790
- numerator2_shape, template_dtype, numerator2_buffer
791
- )
792
- template_filter = backend.sharedarr_to_arr(
793
- filter_shape, filter_dtype, filter_buffer
794
- )
777
+ template_buffer, template_shape, template_dtype = template
778
+ template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
779
+ ft_target = backend.sharedarr_to_arr(*ft_target)
780
+ inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
781
+ numerator2 = backend.sharedarr_to_arr(*numerator2)
782
+ template_filter = backend.sharedarr_to_arr(*template_filter)
795
783
 
796
784
  norm_template, template_mask, mask_sum = False, 1, 1
797
785
  if "template_mask" in kwargs:
798
- template_mask = backend.sharedarr_to_arr(template_shape, template_dtype, kwargs["template_mask"][0])
786
+ template_mask = backend.sharedarr_to_arr(
787
+ kwargs["template_mask"][0], template_shape, template_dtype
788
+ )
799
789
  norm_template, mask_sum = True, backend.sum(template_mask)
800
- norm_template = conditional_execute(_normalize_under_mask, norm_template)
801
790
 
802
791
  arr = backend.preallocate_array(fast_shape, real_dtype)
803
792
  ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
@@ -816,17 +805,15 @@ def corr_scoring(
816
805
  norm_denominator = (backend.sum(inv_denominator) != 1) & (
817
806
  backend.size(inv_denominator) != 1
818
807
  )
819
- filter_template = backend.size(template_filter) != 0
820
808
 
809
+ norm_template = conditional_execute(normalize_under_mask, norm_template)
810
+ callback_func = conditional_execute(callback, callback_class is not None)
821
811
  norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
822
812
  norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
823
- template_filter_func = conditional_execute(backend.multiply, filter_template)
824
-
825
- axis = tuple(range(arr.ndim))
826
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
827
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
813
+ template_filter_func = conditional_execute(
814
+ apply_filter, backend.size(template_filter) != 1
815
+ )
828
816
 
829
- template_sum = backend.sum(template)
830
817
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
831
818
  for index in range(rotations.shape[0]):
832
819
  rotation = rotations[index]
@@ -838,34 +825,24 @@ def corr_scoring(
838
825
  use_geometric_center=False,
839
826
  order=interpolation_order,
840
827
  )
841
- rotation_norm = template_sum / backend.sum(arr)
842
- backend.multiply(arr, rotation_norm, out=arr)
828
+
843
829
  norm_template(arr[unpadded_slice], template_mask, mask_sum)
844
830
 
845
831
  rfftn(arr, ft_temp)
846
- template_filter_func(ft_temp, template_filter, out=ft_temp)
847
-
832
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
848
833
  backend.multiply(ft_target, ft_temp, out=ft_temp)
849
834
  irfftn(ft_temp, arr)
850
835
 
851
836
  norm_func_numerator(arr, numerator2, out=arr)
852
837
  norm_func_denominator(arr, inv_denominator, out=arr)
853
838
 
854
- if fourier_shift_scores:
855
- arr = backend.roll(arr, shift=fourier_shift, axis=axis)
856
-
857
- score = apply_convolution_mode(
858
- arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
839
+ callback_func(
840
+ arr,
841
+ rotation_matrix=rotation,
842
+ rotation_index=index,
843
+ **callback_class_args,
859
844
  )
860
845
 
861
- if callback_class is not None:
862
- callback(
863
- score,
864
- rotation_matrix=rotation,
865
- rotation_index=index,
866
- **callback_class_args,
867
- )
868
-
869
846
  return callback
870
847
 
871
848
 
@@ -913,32 +890,15 @@ def flc_scoring(
913
890
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
914
891
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
915
892
  """
916
- template_buffer, template_shape, template_dtype = template
917
- template_mask_buffer, *_ = template_mask
918
- filter_buffer, filter_shape, filter_dtype = template_filter
919
-
920
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
921
- ft_target2_buffer, *_ = ft_target2
922
-
893
+ callback = callback_class
923
894
  if callback_class is not None and isinstance(callback_class, type):
924
895
  callback = callback_class(**callback_class_args)
925
- elif not isinstance(callback_class, type):
926
- callback = callback_class
927
896
 
928
- # Retrieve objects from shared memory
929
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
930
- template_mask = backend.sharedarr_to_arr(
931
- template_shape, template_dtype, template_mask_buffer
932
- )
933
- ft_target = backend.sharedarr_to_arr(
934
- ft_target_shape, ft_target_dtype, ft_target_buffer
935
- )
936
- ft_target2 = backend.sharedarr_to_arr(
937
- ft_target_shape, ft_target_dtype, ft_target2_buffer
938
- )
939
- template_filter = backend.sharedarr_to_arr(
940
- filter_shape, filter_dtype, filter_buffer
941
- )
897
+ template = backend.sharedarr_to_arr(*template)
898
+ template_mask = backend.sharedarr_to_arr(*template_mask)
899
+ ft_target = backend.sharedarr_to_arr(*ft_target)
900
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
901
+ template_filter = backend.sharedarr_to_arr(*template_filter)
942
902
 
943
903
  arr = backend.preallocate_array(fast_shape, real_dtype)
944
904
  temp = backend.preallocate_array(fast_shape, real_dtype)
@@ -957,12 +917,10 @@ def flc_scoring(
957
917
  temp_fft=ft_temp,
958
918
  )
959
919
  eps = backend.eps(real_dtype)
960
- filter_template = backend.size(template_filter) != 0
961
- template_filter_func = conditional_execute(backend.multiply, filter_template)
962
-
963
- axis = tuple(range(arr.ndim))
964
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
965
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
920
+ template_filter_func = conditional_execute(
921
+ apply_filter, backend.size(template_filter) != 1
922
+ )
923
+ callback_func = conditional_execute(callback, callback_class is not None)
966
924
 
967
925
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
968
926
  for index in range(rotations.shape[0]):
@@ -981,7 +939,7 @@ def flc_scoring(
981
939
  # Given the amount of FFTs, might aswell normalize properly
982
940
  n_observations = backend.sum(temp)
983
941
 
984
- _normalize_under_mask(
942
+ normalize_under_mask(
985
943
  template=arr[unpadded_slice],
986
944
  mask=temp[unpadded_slice],
987
945
  mask_intensity=n_observations,
@@ -1004,7 +962,7 @@ def flc_scoring(
1004
962
  backend.multiply(temp, n_observations, out=temp)
1005
963
 
1006
964
  rfftn(arr, ft_temp)
1007
- template_filter_func(ft_temp, template_filter, out=ft_temp)
965
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1008
966
  backend.multiply(ft_target, ft_temp, out=ft_temp)
1009
967
  irfftn(ft_temp, arr)
1010
968
 
@@ -1013,23 +971,161 @@ def flc_scoring(
1013
971
  backend.fill(temp2, 0)
1014
972
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1015
973
 
1016
- convolution_mode = kwargs.get("convolution_mode", "full")
974
+ callback_func(
975
+ temp2,
976
+ rotation_matrix=rotation,
977
+ rotation_index=index,
978
+ **callback_class_args,
979
+ )
980
+
981
+ return callback
982
+
983
+
984
+ def flc_scoring2(
985
+ template: Tuple[type, Tuple[int], type],
986
+ template_mask: Tuple[type, Tuple[int], type],
987
+ ft_target: Tuple[type, Tuple[int], type],
988
+ ft_target2: Tuple[type, Tuple[int], type],
989
+ template_filter: Tuple[type, Tuple[int], type],
990
+ targetshape: Tuple[int],
991
+ templateshape: Tuple[int],
992
+ fast_shape: Tuple[int],
993
+ fast_ft_shape: Tuple[int],
994
+ rotations: NDArray,
995
+ real_dtype: type,
996
+ complex_dtype: type,
997
+ callback_class: CallbackClass,
998
+ callback_class_args: Dict,
999
+ interpolation_order: int,
1000
+ **kwargs,
1001
+ ) -> CallbackClass:
1002
+ """
1003
+ Computes a normalized cross-correlation score of a target f a template g
1004
+ and a mask m:
1005
+
1006
+ .. math::
1007
+
1008
+ \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
1009
+ {N_m * \\sqrt{
1010
+ \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
1011
+ }
1017
1012
 
1018
- if fourier_shift_scores:
1019
- temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
1013
+ Where:
1014
+
1015
+ .. math::
1020
1016
 
1021
- score = apply_convolution_mode(
1022
- temp2, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
1017
+ CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1018
+
1019
+ and Nm is the number of voxels within the template mask m.
1020
+
1021
+ References
1022
+ ----------
1023
+ .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
1024
+ Microsc. Microanal. 26, 2516 (2020)
1025
+ .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1026
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
1027
+ """
1028
+ callback = callback_class
1029
+ if callback_class is not None and isinstance(callback_class, type):
1030
+ callback = callback_class(**callback_class_args)
1031
+
1032
+ # Retrieve objects from shared memory
1033
+ template = backend.sharedarr_to_arr(*template)
1034
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1035
+ ft_target = backend.sharedarr_to_arr(*ft_target)
1036
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
1037
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1038
+
1039
+ arr = backend.preallocate_array(fast_shape, real_dtype)
1040
+ temp = backend.preallocate_array(fast_shape, real_dtype)
1041
+ temp2 = backend.preallocate_array(fast_shape, real_dtype)
1042
+
1043
+ ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
1044
+ ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
1045
+
1046
+ eps = backend.eps(real_dtype)
1047
+ template_filter_func = conditional_execute(
1048
+ apply_filter, backend.size(template_filter) != 1
1049
+ )
1050
+ callback_func = conditional_execute(callback, callback_class is not None)
1051
+
1052
+ squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
1053
+ squeeze = tuple(
1054
+ slice(0, stop) if i not in squeeze_axis else 0
1055
+ for i, stop in enumerate(template.shape)
1056
+ )
1057
+ squeeze_fast = tuple(
1058
+ slice(0, stop) if i not in squeeze_axis else 0
1059
+ for i, stop in enumerate(fast_shape)
1060
+ )
1061
+ squeeze_fast_ft = tuple(
1062
+ slice(0, stop) if i not in squeeze_axis else 0
1063
+ for i, stop in enumerate(fast_ft_shape)
1064
+ )
1065
+
1066
+ rfftn, irfftn = backend.build_fft(
1067
+ fast_shape=temp[squeeze_fast].shape,
1068
+ fast_ft_shape=fast_ft_shape,
1069
+ real_dtype=real_dtype,
1070
+ complex_dtype=complex_dtype,
1071
+ fftargs=kwargs.get("fftargs", {}),
1072
+ inverse_fast_shape=fast_shape,
1073
+ temp_real=arr[squeeze_fast],
1074
+ temp_fft=ft_temp,
1075
+ )
1076
+ for index in range(rotations.shape[0]):
1077
+ rotation = rotations[index]
1078
+ backend.fill(arr, 0)
1079
+ backend.fill(temp, 0)
1080
+ backend.rotate_array(
1081
+ arr=template[squeeze],
1082
+ arr_mask=template_mask[squeeze],
1083
+ rotation_matrix=rotation,
1084
+ out=arr[squeeze],
1085
+ out_mask=temp[squeeze],
1086
+ use_geometric_center=False,
1087
+ order=interpolation_order,
1023
1088
  )
1089
+ # Given the amount of FFTs, might aswell normalize properly
1090
+ n_observations = backend.sum(temp)
1024
1091
 
1025
- if callback_class is not None:
1026
- callback(
1027
- score,
1028
- rotation_matrix=rotation,
1029
- rotation_index=index,
1030
- **callback_class_args,
1031
- )
1092
+ normalize_under_mask(
1093
+ template=arr[squeeze],
1094
+ mask=temp[squeeze],
1095
+ mask_intensity=n_observations,
1096
+ )
1097
+ rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
1032
1098
 
1099
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1100
+ irfftn(ft_denom, temp)
1101
+ backend.divide(temp, n_observations, out=temp)
1102
+ backend.square(temp, out=temp)
1103
+
1104
+ backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
1105
+ irfftn(ft_denom, temp2)
1106
+ backend.divide(temp2, n_observations, out=temp2)
1107
+
1108
+ backend.subtract(temp2, temp, out=temp)
1109
+ backend.maximum(temp, 0.0, out=temp)
1110
+ backend.sqrt(temp, out=temp)
1111
+ backend.multiply(temp, n_observations, out=temp)
1112
+
1113
+ rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
1114
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1115
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1116
+ irfftn(ft_denom, arr)
1117
+
1118
+ tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
1119
+ nonzero_indices = temp > tol
1120
+ backend.fill(temp2, 0)
1121
+ temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1122
+
1123
+ callback_func(
1124
+ temp2,
1125
+ rotation_matrix=rotation,
1126
+ rotation_index=index,
1127
+ **callback_class_args,
1128
+ )
1033
1129
  return callback
1034
1130
 
1035
1131
 
@@ -1083,35 +1179,18 @@ def mcc_scoring(
1083
1179
  --------
1084
1180
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
1085
1181
  """
1086
- template_buffer, template_shape, template_dtype = template
1087
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
1088
- ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
1089
- template_mask_buffer, _, _ = template
1090
- ft_target_mask_buffer, _, _ = ft_target
1091
- filter_buffer, filter_shape, filter_dtype = template_filter
1092
-
1182
+ callback = callback_class
1093
1183
  if callback_class is not None and isinstance(callback_class, type):
1094
1184
  callback = callback_class(**callback_class_args)
1095
- elif not isinstance(callback_class, type):
1096
- callback = callback_class
1097
1185
 
1098
1186
  # Retrieve objects from shared memory
1099
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
1100
- target_ft = backend.sharedarr_to_arr(
1101
- ft_target_shape, ft_target_dtype, ft_target_buffer
1102
- )
1103
- target_ft2 = backend.sharedarr_to_arr(
1104
- ft_target_shape, ft_target_dtype, ft_target2_buffer
1105
- )
1106
- template_mask = backend.sharedarr_to_arr(
1107
- template_shape, template_dtype, template_mask_buffer
1108
- )
1109
- target_mask_ft = backend.sharedarr_to_arr(
1110
- ft_target_shape, ft_target_dtype, ft_target_mask_buffer
1111
- )
1112
- template_filter = backend.sharedarr_to_arr(
1113
- filter_shape, filter_dtype, filter_buffer
1114
- )
1187
+ template_buffer, template_shape, template_dtype = template
1188
+ template = backend.sharedarr_to_arr(*template)
1189
+ target_ft = backend.sharedarr_to_arr(*ft_target)
1190
+ target_ft2 = backend.sharedarr_to_arr(*ft_target2)
1191
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1192
+ target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
1193
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1115
1194
 
1116
1195
  axes = tuple(range(template.ndim))
1117
1196
  eps = backend.eps(real_dtype)
@@ -1136,14 +1215,10 @@ def mcc_scoring(
1136
1215
  temp_fft=temp_ft,
1137
1216
  )
1138
1217
 
1139
- filter_template = backend.size(template_filter) != 0
1140
- template_filter_func = conditional_execute(backend.multiply, filter_template)
1141
-
1142
- axis = tuple(range(template.ndim))
1143
- fourier_shift = callback_class_args.get(
1144
- "fourier_shift", backend.zeros(template.ndim)
1218
+ template_filter_func = conditional_execute(
1219
+ apply_filter, backend.size(template_filter) != 1
1145
1220
  )
1146
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
1221
+ callback_func = conditional_execute(callback, callback_class is not None)
1147
1222
 
1148
1223
  # Calculate scores across all rotations
1149
1224
  for index in range(rotations.shape[0]):
@@ -1165,7 +1240,7 @@ def mcc_scoring(
1165
1240
 
1166
1241
  # template_rot_ft
1167
1242
  rfftn(template_rot, temp_ft)
1168
- template_filter_func(temp_ft, template_filter, out=temp_ft)
1243
+ template_filter_func(ft_template=temp_ft, template_filter=template_filter)
1169
1244
  irfftn(target_mask_ft * temp_ft, temp2)
1170
1245
  irfftn(target_ft * temp_ft, numerator)
1171
1246
 
@@ -1221,25 +1296,69 @@ def mcc_scoring(
1221
1296
  mask_overlap, axis=axes, keepdims=True
1222
1297
  )
1223
1298
  temp[mask_overlap < number_px_threshold] = 0.0
1224
- convolution_mode = kwargs.get("convolution_mode", "full")
1225
-
1226
- if fourier_shift_scores:
1227
- temp = backend.roll(temp, shift=fourier_shift, axis=axis)
1228
1299
 
1229
- score = apply_convolution_mode(
1230
- temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
1300
+ callback_func(
1301
+ temp,
1302
+ rotation_matrix=rotation,
1303
+ rotation_index=index,
1304
+ **callback_class_args,
1231
1305
  )
1232
- if callback_class is not None:
1233
- callback(
1234
- score,
1235
- rotation_matrix=rotation,
1236
- rotation_index=index,
1237
- **callback_class_args,
1238
- )
1239
1306
 
1240
1307
  return callback
1241
1308
 
1242
1309
 
1310
+ def _setup_template_filter(
1311
+ matching_data: MatchingData,
1312
+ rfftn: Callable,
1313
+ irfftn: Callable,
1314
+ fast_shape: Tuple[int],
1315
+ fast_ft_shape: Tuple[int],
1316
+ ):
1317
+ filter_template = isinstance(matching_data.template_filter, Compose)
1318
+ filter_target = isinstance(matching_data.target_filter, Compose)
1319
+
1320
+ template_filter = backend.full(
1321
+ shape=(1,), fill_value=1, dtype=backend._default_dtype
1322
+ )
1323
+
1324
+ if not filter_template and not filter_target:
1325
+ return template_filter
1326
+
1327
+ target_temp = backend.astype(
1328
+ backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
1329
+ )
1330
+ target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1331
+ rfftn(target_temp, target_temp_ft)
1332
+
1333
+ if isinstance(matching_data.template_filter, Compose):
1334
+ template_filter = matching_data.template_filter(
1335
+ shape=fast_shape,
1336
+ return_real_fourier=True,
1337
+ shape_is_real_fourier=False,
1338
+ data_rfft=target_temp_ft,
1339
+ )
1340
+ template_filter = template_filter["data"]
1341
+ template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1342
+
1343
+ if isinstance(matching_data.target_filter, Compose):
1344
+ target_filter = matching_data.target_filter(
1345
+ shape=fast_shape,
1346
+ return_real_fourier=True,
1347
+ shape_is_real_fourier=False,
1348
+ data_rfft=target_temp_ft,
1349
+ weight_type=None,
1350
+ )
1351
+ target_filter = target_filter["data"]
1352
+ backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1353
+
1354
+ irfftn(target_temp_ft, target_temp)
1355
+ matching_data._target = backend.topleft_pad(
1356
+ target_temp, matching_data.target.shape
1357
+ )
1358
+
1359
+ return template_filter
1360
+
1361
+
1243
1362
  def device_memory_handler(func: Callable):
1244
1363
  """Decorator function providing SharedMemory Handler."""
1245
1364
 
@@ -1310,18 +1429,17 @@ def scan(
1310
1429
  Tuple
1311
1430
  The merged results from callback_class if provided otherwise None.
1312
1431
  """
1432
+ matching_data.to_backend()
1313
1433
  shape_diff = backend.subtract(
1314
- matching_data._target.shape, matching_data._template.shape
1434
+ matching_data._output_target_shape, matching_data._output_template_shape
1315
1435
  )
1436
+ shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
1316
1437
  if backend.sum(shape_diff < 0) and not pad_fourier:
1317
1438
  warnings.warn(
1318
1439
  "Target is larger than template and Fourier padding is turned off."
1319
1440
  " This can lead to shifted results. You can swap template and target, or"
1320
1441
  " zero-pad the target."
1321
1442
  )
1322
-
1323
- matching_data.to_backend()
1324
-
1325
1443
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1326
1444
  pad_fourier=pad_fourier
1327
1445
  )
@@ -1334,6 +1452,15 @@ def scan(
1334
1452
  complex_dtype=matching_data._complex_dtype,
1335
1453
  fftargs=fftargs,
1336
1454
  )
1455
+
1456
+ # template_filter = _setup_template_filter(
1457
+ # matching_data=matching_data,
1458
+ # rfftn=rfftn,
1459
+ # irfftn=irfftn,
1460
+ # fast_shape=fast_shape,
1461
+ # fast_ft_shape=fast_ft_shape,
1462
+ # )
1463
+
1337
1464
  setup = matching_setup(
1338
1465
  rfftn=rfftn,
1339
1466
  irfftn=irfftn,
@@ -1351,6 +1478,7 @@ def scan(
1351
1478
  )
1352
1479
  rfftn, irfftn = None, None
1353
1480
 
1481
+
1354
1482
  template_filter, preprocessor = None, Preprocessor()
1355
1483
  for method, parameters in matching_data.template_filter.items():
1356
1484
  parameters["shape"] = fast_shape
@@ -1364,9 +1492,8 @@ def scan(
1364
1492
  template_filter = backend.full(
1365
1493
  shape=(1,), fill_value=1, dtype=backend._default_dtype
1366
1494
  )
1367
- else:
1368
- template_filter = backend.to_backend_array(template_filter)
1369
1495
 
1496
+ template_filter = backend.to_backend_array(template_filter)
1370
1497
  template_filter = backend.astype(template_filter, backend._default_dtype)
1371
1498
  template_filter_buffer = backend.arr_to_sharedarr(
1372
1499
  arr=template_filter,
@@ -1388,14 +1515,21 @@ def scan(
1388
1515
  callback_class = setup.pop("callback_class", callback_class)
1389
1516
  callback_class_args = setup.pop("callback_class_args", callback_class_args)
1390
1517
  callback_classes = [callback_class for _ in range(n_callback_classes)]
1518
+
1519
+ convolution_mode = "same"
1520
+ if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1521
+ convolution_mode = "valid"
1522
+
1523
+
1524
+ callback_class_args["fourier_shift"] = fourier_shift
1525
+ callback_class_args["convolution_mode"] = convolution_mode
1526
+ callback_class_args["targetshape"] = setup["targetshape"]
1527
+ callback_class_args["templateshape"] = setup["templateshape"]
1528
+
1391
1529
  if callback_class == MaxScoreOverRotations:
1392
- score_space_shape = backend.subtract(
1393
- matching_data.target.shape,
1394
- matching_data._target_pad,
1395
- )
1396
1530
  callback_classes = [
1397
1531
  class_name(
1398
- score_space_shape=score_space_shape,
1532
+ score_space_shape=fast_shape,
1399
1533
  score_space_dtype=matching_data._default_dtype,
1400
1534
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
1401
1535
  rotation_space_dtype=backend._default_dtype_int,
@@ -1435,10 +1569,16 @@ def scan(
1435
1569
  for index, rotation in enumerate(rotation_list)
1436
1570
  )
1437
1571
 
1572
+ callbacks = callbacks[0:n_callback_classes]
1438
1573
  callbacks = [
1439
- tuple(callback)
1440
- for callback in callbacks[0:n_callback_classes]
1441
- if callback is not None
1574
+ tuple(callback._postprocess(
1575
+ fourier_shift = fourier_shift,
1576
+ convolution_mode = convolution_mode,
1577
+ targetshape = setup["targetshape"],
1578
+ templateshape = setup["templateshape"],
1579
+ shared_memory_handler=kwargs.get("shared_memory_handler", None)
1580
+ )) if hasattr(callback, "_postprocess") else tuple(callback)
1581
+ for callback in callbacks if callback is not None
1442
1582
  ]
1443
1583
  backend.free_cache()
1444
1584
 
@@ -1549,11 +1689,13 @@ def scan_subsets(
1549
1689
  matching_data._target, matching_data._template = None, None
1550
1690
  matching_data._target_mask, matching_data._template_mask = None, None
1551
1691
 
1692
+ candidates = None
1552
1693
  if callback_class is not None:
1553
1694
  candidates = callback_class.merge(
1554
1695
  results, **callback_class_args, inner_merge=False
1555
1696
  )
1556
- return candidates
1697
+
1698
+ return candidates
1557
1699
 
1558
1700
 
1559
1701
  MATCHING_EXHAUSTIVE_REGISTER = {
@@ -1563,6 +1705,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
1563
1705
  "CAM": (cam_setup, corr_scoring),
1564
1706
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1565
1707
  "FLC": (flc_setup, flc_scoring),
1708
+ "FLC2": (flc_setup, flc_scoring2),
1566
1709
  "MCC": (mcc_setup, mcc_scoring),
1567
1710
  }
1568
1711