pytme 0.1.8__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/match_template.py +148 -126
  2. pytme-0.2.0b0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0b0.dist-info/RECORD +66 -0
  6. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +148 -126
  9. scripts/match_template_filters.py +852 -0
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +545 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +33 -2
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +156 -63
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +76 -32
  23. tme/matching_exhaustive.py +366 -204
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +728 -651
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessor.py +21 -18
  29. tme/structure.py +2 -37
  30. pytme-0.1.8.data/scripts/postprocess.py +0 -625
  31. pytme-0.1.8.dist-info/RECORD +0 -61
  32. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/estimate_ram_usage.py +0 -0
  33. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/preprocess.py +0 -0
  34. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/LICENSE +0 -0
  35. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/entry_points.txt +0 -0
  36. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from .matching_utils import (
25
25
  conditional_execute,
26
26
  )
27
27
  from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
28
+ # from .preprocessing import Compose
28
29
  from .preprocessor import Preprocessor
29
30
  from .matching_data import MatchingData
30
31
  from .backends import backend
@@ -43,6 +44,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
43
44
  return scan(**kwargs)
44
45
 
45
46
 
47
+ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
48
+ """
49
+ Standardizes the values in in template by subtracting the mean and dividing by the
50
+ standard deviation based on the elements in mask. Subsequently, the template is
51
+ multiplied by the mask.
52
+
53
+ Parameters
54
+ ----------
55
+ template : NDArray
56
+ The data array to be normalized. This array is modified in-place.
57
+ mask : NDArray
58
+ A boolean array of the same shape as `template`. True values indicate the positions in `template`
59
+ to consider for normalization.
60
+ mask_intensity : float
61
+ Mask intensity used to compute expectations.
62
+
63
+ References
64
+ ----------
65
+ .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
66
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
67
+ .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
68
+ R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
69
+ 13375 (2023)
70
+
71
+ Returns
72
+ -------
73
+ None
74
+ This function modifies `template` in-place and does not return any value.
75
+ """
76
+ masked_mean = backend.sum(backend.multiply(template, mask))
77
+ masked_mean = backend.divide(masked_mean, mask_intensity)
78
+ masked_std = backend.sum(backend.multiply(backend.square(template), mask))
79
+ masked_std = backend.subtract(
80
+ masked_std / mask_intensity, backend.square(masked_mean)
81
+ )
82
+ masked_std = backend.sqrt(backend.maximum(masked_std, 0))
83
+
84
+ backend.subtract(template, masked_mean, out=template)
85
+ backend.divide(template, masked_std, out=template)
86
+ backend.multiply(template, mask, out=template)
87
+ return None
88
+
89
+
90
+ def apply_filter(ft_template, template_filter):
91
+ # This is an approximation to applying the mask, irfftn, normalize, rfftn
92
+ std_before = backend.std(ft_template)
93
+ backend.multiply(ft_template, template_filter, out=ft_template)
94
+ backend.multiply(
95
+ ft_template, std_before / backend.std(ft_template), out=ft_template
96
+ )
97
+
98
+
46
99
  def cc_setup(
47
100
  rfftn: Callable,
48
101
  irfftn: Callable,
@@ -208,11 +261,16 @@ def corr_setup(
208
261
  target_pad, ft_target2, ft_window_template = None, None, None
209
262
 
210
263
  # Normalizing constants
211
- template_mean = backend.mean(template)
212
- template_volume = np.prod(template.shape)
264
+ n_observations = backend.sum(template_mask)
265
+ template_mean = backend.sum(backend.multiply(template, template_mask))
266
+ template_mean = backend.divide(template_mean, n_observations)
213
267
  template_ssd = backend.sum(
214
- backend.square(backend.subtract(template, template_mean))
268
+ backend.square(
269
+ backend.multiply(backend.multiply(template, template_mean), template_mask)
270
+ )
215
271
  )
272
+ template_volume = np.prod(template.shape)
273
+ backend.multiply(template, template_mask, out=template)
216
274
 
217
275
  # Final numerator is score - numerator2
218
276
  numerator2 = backend.multiply(target_window_sum, template_mean)
@@ -308,49 +366,6 @@ def cam_setup(**kwargs):
308
366
  return corr_setup(**kwargs)
309
367
 
310
368
 
311
- def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
312
- """
313
- Standardizes the values in in template by subtracting the mean and dividing by the
314
- standard deviation based on the elements in mask. Subsequently, the template is
315
- multiplied by the mask.
316
-
317
- Parameters
318
- ----------
319
- template : NDArray
320
- The data array to be normalized. This array is modified in-place.
321
- mask : NDArray
322
- A boolean array of the same shape as `template`. True values indicate the positions in `template`
323
- to consider for normalization.
324
- mask_intensity : float
325
- Mask intensity used to compute expectations.
326
-
327
- References
328
- ----------
329
- .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
330
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
331
- .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
332
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
333
- 13375 (2023)
334
-
335
- Returns
336
- -------
337
- None
338
- This function modifies `template` in-place and does not return any value.
339
- """
340
- masked_mean = backend.sum(backend.multiply(template, mask))
341
- masked_mean = backend.divide(masked_mean, mask_intensity)
342
- masked_std = backend.sum(backend.multiply(backend.square(template), mask))
343
- masked_std = backend.subtract(
344
- masked_std / mask_intensity, backend.square(masked_mean)
345
- )
346
- masked_std = backend.sqrt(backend.maximum(masked_std, 0))
347
-
348
- backend.subtract(template, masked_mean, out=template)
349
- backend.divide(template, masked_std, out=template)
350
- backend.multiply(template, mask, out=template)
351
- return None
352
-
353
-
354
369
  def flc_setup(
355
370
  rfftn: Callable,
356
371
  irfftn: Callable,
@@ -413,7 +428,7 @@ def flc_setup(
413
428
  arr=ft_target2, shared_memory_handler=shared_memory_handler
414
429
  )
415
430
 
416
- _normalize_under_mask(
431
+ normalize_under_mask(
417
432
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
418
433
  )
419
434
 
@@ -535,13 +550,16 @@ def flcSphericalMask_setup(
535
550
  backend.fill(temp2, 0)
536
551
  temp2[nonzero_indices] = 1 / temp[nonzero_indices]
537
552
 
538
- _normalize_under_mask(
553
+ normalize_under_mask(
539
554
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
540
555
  )
541
556
 
542
557
  template_buffer = backend.arr_to_sharedarr(
543
558
  arr=template, shared_memory_handler=shared_memory_handler
544
559
  )
560
+ template_mask_buffer = backend.arr_to_sharedarr(
561
+ arr=template_mask, shared_memory_handler=shared_memory_handler
562
+ )
545
563
  target_ft_buffer = backend.arr_to_sharedarr(
546
564
  arr=ft_target, shared_memory_handler=shared_memory_handler
547
565
  )
@@ -553,6 +571,7 @@ def flcSphericalMask_setup(
553
571
  )
554
572
 
555
573
  template_tuple = (template_buffer, template.shape, real_dtype)
574
+ template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
556
575
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
557
576
 
558
577
  inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
@@ -560,6 +579,7 @@ def flcSphericalMask_setup(
560
579
 
561
580
  ret = {
562
581
  "template": template_tuple,
582
+ "template_mask": template_mask_tuple,
563
583
  "ft_target": target_ft_tuple,
564
584
  "inv_denominator": inv_denominator_tuple,
565
585
  "numerator2": numerator2_tuple,
@@ -717,10 +737,6 @@ def corr_scoring(
717
737
  datatype.
718
738
  numerator2 : Tuple[type, Tuple[int], type]
719
739
  Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
720
- targetshape : Tuple[int]
721
- The shape of the target.
722
- templateshape : Tuple[int]
723
- The shape of the template.
724
740
  fast_shape : Tuple[int]
725
741
  The shape for fast Fourier transform.
726
742
  fast_ft_shape : Tuple[int]
@@ -739,8 +755,6 @@ def corr_scoring(
739
755
  instantiable.
740
756
  interpolation_order : int
741
757
  The order of interpolation to be used while rotating the template.
742
- convolution_mode : str, optional
743
- Mode to use for convolution, default is "full".
744
758
  **kwargs :
745
759
  Additional arguments to be passed to the function.
746
760
 
@@ -756,31 +770,23 @@ def corr_scoring(
756
770
  :py:meth:`cam_setup`
757
771
  :py:meth:`flcSphericalMask_setup`
758
772
  """
759
- template_buffer, template_shape, template_dtype = template
760
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
761
- inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
762
- numerator2_buffer, numerator2_shape, _ = numerator2
763
- filter_buffer, filter_shape, filter_dtype = template_filter
764
-
773
+ callback = callback_class
765
774
  if callback_class is not None and isinstance(callback_class, type):
766
775
  callback = callback_class(**callback_class_args)
767
- elif not isinstance(callback_class, type):
768
- callback = callback_class
769
776
 
770
- # Retrieve objects from shared memory
771
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
772
- ft_target = backend.sharedarr_to_arr(
773
- ft_target_shape, ft_target_dtype, ft_target_buffer
774
- )
775
- inv_denominator = backend.sharedarr_to_arr(
776
- inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
777
- )
778
- numerator2 = backend.sharedarr_to_arr(
779
- numerator2_shape, template_dtype, numerator2_buffer
780
- )
781
- template_filter = backend.sharedarr_to_arr(
782
- filter_shape, filter_dtype, filter_buffer
783
- )
777
+ template_buffer, template_shape, template_dtype = template
778
+ template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
779
+ ft_target = backend.sharedarr_to_arr(*ft_target)
780
+ inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
781
+ numerator2 = backend.sharedarr_to_arr(*numerator2)
782
+ template_filter = backend.sharedarr_to_arr(*template_filter)
783
+
784
+ norm_template, template_mask, mask_sum = False, 1, 1
785
+ if "template_mask" in kwargs:
786
+ template_mask = backend.sharedarr_to_arr(
787
+ kwargs["template_mask"][0], template_shape, template_dtype
788
+ )
789
+ norm_template, mask_sum = True, backend.sum(template_mask)
784
790
 
785
791
  arr = backend.preallocate_array(fast_shape, real_dtype)
786
792
  ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
@@ -799,17 +805,16 @@ def corr_scoring(
799
805
  norm_denominator = (backend.sum(inv_denominator) != 1) & (
800
806
  backend.size(inv_denominator) != 1
801
807
  )
802
- filter_template = backend.size(template_filter) != 0
803
808
 
809
+ norm_template = conditional_execute(normalize_under_mask, norm_template)
810
+ callback_func = conditional_execute(callback, callback_class is not None)
804
811
  norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
805
812
  norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
806
- template_filter_func = conditional_execute(backend.multiply, filter_template)
807
-
808
- axis = tuple(range(arr.ndim))
809
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
810
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
813
+ template_filter_func = conditional_execute(
814
+ apply_filter, backend.size(template_filter) != 1
815
+ )
811
816
 
812
- template_sum = backend.sum(template)
817
+ unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
813
818
  for index in range(rotations.shape[0]):
814
819
  rotation = rotations[index]
815
820
  backend.fill(arr, 0)
@@ -820,33 +825,24 @@ def corr_scoring(
820
825
  use_geometric_center=False,
821
826
  order=interpolation_order,
822
827
  )
823
- rotation_norm = template_sum / backend.sum(arr)
824
- backend.multiply(arr, rotation_norm, out=arr)
825
828
 
826
- rfftn(arr, ft_temp)
827
- template_filter_func(ft_temp, template_filter, out=ft_temp)
829
+ norm_template(arr[unpadded_slice], template_mask, mask_sum)
828
830
 
831
+ rfftn(arr, ft_temp)
832
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
829
833
  backend.multiply(ft_target, ft_temp, out=ft_temp)
830
834
  irfftn(ft_temp, arr)
831
835
 
832
836
  norm_func_numerator(arr, numerator2, out=arr)
833
837
  norm_func_denominator(arr, inv_denominator, out=arr)
834
838
 
835
- if fourier_shift_scores:
836
- arr = backend.roll(arr, shift=fourier_shift, axis=axis)
837
-
838
- score = apply_convolution_mode(
839
- arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
839
+ callback_func(
840
+ arr,
841
+ rotation_matrix=rotation,
842
+ rotation_index=index,
843
+ **callback_class_args,
840
844
  )
841
845
 
842
- if callback_class is not None:
843
- callback(
844
- score,
845
- rotation_matrix=rotation,
846
- rotation_index=index,
847
- **callback_class_args,
848
- )
849
-
850
846
  return callback
851
847
 
852
848
 
@@ -894,32 +890,15 @@ def flc_scoring(
894
890
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
895
891
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
896
892
  """
897
- template_buffer, template_shape, template_dtype = template
898
- template_mask_buffer, *_ = template_mask
899
- filter_buffer, filter_shape, filter_dtype = template_filter
900
-
901
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
902
- ft_target2_buffer, *_ = ft_target2
903
-
893
+ callback = callback_class
904
894
  if callback_class is not None and isinstance(callback_class, type):
905
895
  callback = callback_class(**callback_class_args)
906
- elif not isinstance(callback_class, type):
907
- callback = callback_class
908
896
 
909
- # Retrieve objects from shared memory
910
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
911
- template_mask = backend.sharedarr_to_arr(
912
- template_shape, template_dtype, template_mask_buffer
913
- )
914
- ft_target = backend.sharedarr_to_arr(
915
- ft_target_shape, ft_target_dtype, ft_target_buffer
916
- )
917
- ft_target2 = backend.sharedarr_to_arr(
918
- ft_target_shape, ft_target_dtype, ft_target2_buffer
919
- )
920
- template_filter = backend.sharedarr_to_arr(
921
- filter_shape, filter_dtype, filter_buffer
922
- )
897
+ template = backend.sharedarr_to_arr(*template)
898
+ template_mask = backend.sharedarr_to_arr(*template_mask)
899
+ ft_target = backend.sharedarr_to_arr(*ft_target)
900
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
901
+ template_filter = backend.sharedarr_to_arr(*template_filter)
923
902
 
924
903
  arr = backend.preallocate_array(fast_shape, real_dtype)
925
904
  temp = backend.preallocate_array(fast_shape, real_dtype)
@@ -938,12 +917,10 @@ def flc_scoring(
938
917
  temp_fft=ft_temp,
939
918
  )
940
919
  eps = backend.eps(real_dtype)
941
- filter_template = backend.size(template_filter) != 0
942
- template_filter_func = conditional_execute(backend.multiply, filter_template)
943
-
944
- axis = tuple(range(arr.ndim))
945
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
946
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
920
+ template_filter_func = conditional_execute(
921
+ apply_filter, backend.size(template_filter) != 1
922
+ )
923
+ callback_func = conditional_execute(callback, callback_class is not None)
947
924
 
948
925
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
949
926
  for index in range(rotations.shape[0]):
@@ -962,7 +939,7 @@ def flc_scoring(
962
939
  # Given the amount of FFTs, might aswell normalize properly
963
940
  n_observations = backend.sum(temp)
964
941
 
965
- _normalize_under_mask(
942
+ normalize_under_mask(
966
943
  template=arr[unpadded_slice],
967
944
  mask=temp[unpadded_slice],
968
945
  mask_intensity=n_observations,
@@ -985,7 +962,7 @@ def flc_scoring(
985
962
  backend.multiply(temp, n_observations, out=temp)
986
963
 
987
964
  rfftn(arr, ft_temp)
988
- template_filter_func(ft_temp, template_filter, out=ft_temp)
965
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
989
966
  backend.multiply(ft_target, ft_temp, out=ft_temp)
990
967
  irfftn(ft_temp, arr)
991
968
 
@@ -994,23 +971,161 @@ def flc_scoring(
994
971
  backend.fill(temp2, 0)
995
972
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
996
973
 
997
- convolution_mode = kwargs.get("convolution_mode", "full")
974
+ callback_func(
975
+ temp2,
976
+ rotation_matrix=rotation,
977
+ rotation_index=index,
978
+ **callback_class_args,
979
+ )
998
980
 
999
- if fourier_shift_scores:
1000
- temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
981
+ return callback
1001
982
 
1002
- score = apply_convolution_mode(
1003
- temp2, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
983
+
984
+ def flc_scoring2(
985
+ template: Tuple[type, Tuple[int], type],
986
+ template_mask: Tuple[type, Tuple[int], type],
987
+ ft_target: Tuple[type, Tuple[int], type],
988
+ ft_target2: Tuple[type, Tuple[int], type],
989
+ template_filter: Tuple[type, Tuple[int], type],
990
+ targetshape: Tuple[int],
991
+ templateshape: Tuple[int],
992
+ fast_shape: Tuple[int],
993
+ fast_ft_shape: Tuple[int],
994
+ rotations: NDArray,
995
+ real_dtype: type,
996
+ complex_dtype: type,
997
+ callback_class: CallbackClass,
998
+ callback_class_args: Dict,
999
+ interpolation_order: int,
1000
+ **kwargs,
1001
+ ) -> CallbackClass:
1002
+ """
1003
+ Computes a normalized cross-correlation score of a target f a template g
1004
+ and a mask m:
1005
+
1006
+ .. math::
1007
+
1008
+ \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
1009
+ {N_m * \\sqrt{
1010
+ \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
1011
+ }
1012
+
1013
+ Where:
1014
+
1015
+ .. math::
1016
+
1017
+ CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1018
+
1019
+ and Nm is the number of voxels within the template mask m.
1020
+
1021
+ References
1022
+ ----------
1023
+ .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
1024
+ Microsc. Microanal. 26, 2516 (2020)
1025
+ .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1026
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
1027
+ """
1028
+ callback = callback_class
1029
+ if callback_class is not None and isinstance(callback_class, type):
1030
+ callback = callback_class(**callback_class_args)
1031
+
1032
+ # Retrieve objects from shared memory
1033
+ template = backend.sharedarr_to_arr(*template)
1034
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1035
+ ft_target = backend.sharedarr_to_arr(*ft_target)
1036
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
1037
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1038
+
1039
+ arr = backend.preallocate_array(fast_shape, real_dtype)
1040
+ temp = backend.preallocate_array(fast_shape, real_dtype)
1041
+ temp2 = backend.preallocate_array(fast_shape, real_dtype)
1042
+
1043
+ ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
1044
+ ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
1045
+
1046
+ eps = backend.eps(real_dtype)
1047
+ template_filter_func = conditional_execute(
1048
+ apply_filter, backend.size(template_filter) != 1
1049
+ )
1050
+ callback_func = conditional_execute(callback, callback_class is not None)
1051
+
1052
+ squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
1053
+ squeeze = tuple(
1054
+ slice(0, stop) if i not in squeeze_axis else 0
1055
+ for i, stop in enumerate(template.shape)
1056
+ )
1057
+ squeeze_fast = tuple(
1058
+ slice(0, stop) if i not in squeeze_axis else 0
1059
+ for i, stop in enumerate(fast_shape)
1060
+ )
1061
+ squeeze_fast_ft = tuple(
1062
+ slice(0, stop) if i not in squeeze_axis else 0
1063
+ for i, stop in enumerate(fast_ft_shape)
1064
+ )
1065
+
1066
+ rfftn, irfftn = backend.build_fft(
1067
+ fast_shape=temp[squeeze_fast].shape,
1068
+ fast_ft_shape=fast_ft_shape,
1069
+ real_dtype=real_dtype,
1070
+ complex_dtype=complex_dtype,
1071
+ fftargs=kwargs.get("fftargs", {}),
1072
+ inverse_fast_shape=fast_shape,
1073
+ temp_real=arr[squeeze_fast],
1074
+ temp_fft=ft_temp,
1075
+ )
1076
+ for index in range(rotations.shape[0]):
1077
+ rotation = rotations[index]
1078
+ backend.fill(arr, 0)
1079
+ backend.fill(temp, 0)
1080
+ backend.rotate_array(
1081
+ arr=template[squeeze],
1082
+ arr_mask=template_mask[squeeze],
1083
+ rotation_matrix=rotation,
1084
+ out=arr[squeeze],
1085
+ out_mask=temp[squeeze],
1086
+ use_geometric_center=False,
1087
+ order=interpolation_order,
1004
1088
  )
1089
+ # Given the amount of FFTs, might aswell normalize properly
1090
+ n_observations = backend.sum(temp)
1005
1091
 
1006
- if callback_class is not None:
1007
- callback(
1008
- score,
1009
- rotation_matrix=rotation,
1010
- rotation_index=index,
1011
- **callback_class_args,
1012
- )
1092
+ normalize_under_mask(
1093
+ template=arr[squeeze],
1094
+ mask=temp[squeeze],
1095
+ mask_intensity=n_observations,
1096
+ )
1097
+ rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
1098
+
1099
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1100
+ irfftn(ft_denom, temp)
1101
+ backend.divide(temp, n_observations, out=temp)
1102
+ backend.square(temp, out=temp)
1103
+
1104
+ backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
1105
+ irfftn(ft_denom, temp2)
1106
+ backend.divide(temp2, n_observations, out=temp2)
1107
+
1108
+ backend.subtract(temp2, temp, out=temp)
1109
+ backend.maximum(temp, 0.0, out=temp)
1110
+ backend.sqrt(temp, out=temp)
1111
+ backend.multiply(temp, n_observations, out=temp)
1112
+
1113
+ rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
1114
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1115
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1116
+ irfftn(ft_denom, arr)
1013
1117
 
1118
+ tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
1119
+ nonzero_indices = temp > tol
1120
+ backend.fill(temp2, 0)
1121
+ temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1122
+
1123
+ callback_func(
1124
+ temp2,
1125
+ rotation_matrix=rotation,
1126
+ rotation_index=index,
1127
+ **callback_class_args,
1128
+ )
1014
1129
  return callback
1015
1130
 
1016
1131
 
@@ -1064,35 +1179,18 @@ def mcc_scoring(
1064
1179
  --------
1065
1180
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
1066
1181
  """
1067
- template_buffer, template_shape, template_dtype = template
1068
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
1069
- ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
1070
- template_mask_buffer, _, _ = template
1071
- ft_target_mask_buffer, _, _ = ft_target
1072
- filter_buffer, filter_shape, filter_dtype = template_filter
1073
-
1182
+ callback = callback_class
1074
1183
  if callback_class is not None and isinstance(callback_class, type):
1075
1184
  callback = callback_class(**callback_class_args)
1076
- elif not isinstance(callback_class, type):
1077
- callback = callback_class
1078
1185
 
1079
1186
  # Retrieve objects from shared memory
1080
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
1081
- target_ft = backend.sharedarr_to_arr(
1082
- ft_target_shape, ft_target_dtype, ft_target_buffer
1083
- )
1084
- target_ft2 = backend.sharedarr_to_arr(
1085
- ft_target_shape, ft_target_dtype, ft_target2_buffer
1086
- )
1087
- template_mask = backend.sharedarr_to_arr(
1088
- template_shape, template_dtype, template_mask_buffer
1089
- )
1090
- target_mask_ft = backend.sharedarr_to_arr(
1091
- ft_target_shape, ft_target_dtype, ft_target_mask_buffer
1092
- )
1093
- template_filter = backend.sharedarr_to_arr(
1094
- filter_shape, filter_dtype, filter_buffer
1095
- )
1187
+ template_buffer, template_shape, template_dtype = template
1188
+ template = backend.sharedarr_to_arr(*template)
1189
+ target_ft = backend.sharedarr_to_arr(*ft_target)
1190
+ target_ft2 = backend.sharedarr_to_arr(*ft_target2)
1191
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1192
+ target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
1193
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1096
1194
 
1097
1195
  axes = tuple(range(template.ndim))
1098
1196
  eps = backend.eps(real_dtype)
@@ -1117,14 +1215,10 @@ def mcc_scoring(
1117
1215
  temp_fft=temp_ft,
1118
1216
  )
1119
1217
 
1120
- filter_template = backend.size(template_filter) != 0
1121
- template_filter_func = conditional_execute(backend.multiply, filter_template)
1122
-
1123
- axis = tuple(range(template.ndim))
1124
- fourier_shift = callback_class_args.get(
1125
- "fourier_shift", backend.zeros(template.ndim)
1218
+ template_filter_func = conditional_execute(
1219
+ apply_filter, backend.size(template_filter) != 1
1126
1220
  )
1127
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
1221
+ callback_func = conditional_execute(callback, callback_class is not None)
1128
1222
 
1129
1223
  # Calculate scores across all rotations
1130
1224
  for index in range(rotations.shape[0]):
@@ -1146,7 +1240,7 @@ def mcc_scoring(
1146
1240
 
1147
1241
  # template_rot_ft
1148
1242
  rfftn(template_rot, temp_ft)
1149
- template_filter_func(temp_ft, template_filter, out=temp_ft)
1243
+ template_filter_func(ft_template=temp_ft, template_filter=template_filter)
1150
1244
  irfftn(target_mask_ft * temp_ft, temp2)
1151
1245
  irfftn(target_ft * temp_ft, numerator)
1152
1246
 
@@ -1202,25 +1296,69 @@ def mcc_scoring(
1202
1296
  mask_overlap, axis=axes, keepdims=True
1203
1297
  )
1204
1298
  temp[mask_overlap < number_px_threshold] = 0.0
1205
- convolution_mode = kwargs.get("convolution_mode", "full")
1206
1299
 
1207
- if fourier_shift_scores:
1208
- temp = backend.roll(temp, shift=fourier_shift, axis=axis)
1209
-
1210
- score = apply_convolution_mode(
1211
- temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
1300
+ callback_func(
1301
+ temp,
1302
+ rotation_matrix=rotation,
1303
+ rotation_index=index,
1304
+ **callback_class_args,
1212
1305
  )
1213
- if callback_class is not None:
1214
- callback(
1215
- score,
1216
- rotation_matrix=rotation,
1217
- rotation_index=index,
1218
- **callback_class_args,
1219
- )
1220
1306
 
1221
1307
  return callback
1222
1308
 
1223
1309
 
1310
+ def _setup_template_filter(
1311
+ matching_data: MatchingData,
1312
+ rfftn: Callable,
1313
+ irfftn: Callable,
1314
+ fast_shape: Tuple[int],
1315
+ fast_ft_shape: Tuple[int],
1316
+ ):
1317
+ filter_template = isinstance(matching_data.template_filter, Compose)
1318
+ filter_target = isinstance(matching_data.target_filter, Compose)
1319
+
1320
+ template_filter = backend.full(
1321
+ shape=(1,), fill_value=1, dtype=backend._default_dtype
1322
+ )
1323
+
1324
+ if not filter_template and not filter_target:
1325
+ return template_filter
1326
+
1327
+ target_temp = backend.astype(
1328
+ backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
1329
+ )
1330
+ target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1331
+ rfftn(target_temp, target_temp_ft)
1332
+
1333
+ if isinstance(matching_data.template_filter, Compose):
1334
+ template_filter = matching_data.template_filter(
1335
+ shape=fast_shape,
1336
+ return_real_fourier=True,
1337
+ shape_is_real_fourier=False,
1338
+ data_rfft=target_temp_ft,
1339
+ )
1340
+ template_filter = template_filter["data"]
1341
+ template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1342
+
1343
+ if isinstance(matching_data.target_filter, Compose):
1344
+ target_filter = matching_data.target_filter(
1345
+ shape=fast_shape,
1346
+ return_real_fourier=True,
1347
+ shape_is_real_fourier=False,
1348
+ data_rfft=target_temp_ft,
1349
+ weight_type=None,
1350
+ )
1351
+ target_filter = target_filter["data"]
1352
+ backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1353
+
1354
+ irfftn(target_temp_ft, target_temp)
1355
+ matching_data._target = backend.topleft_pad(
1356
+ target_temp, matching_data.target.shape
1357
+ )
1358
+
1359
+ return template_filter
1360
+
1361
+
1224
1362
  def device_memory_handler(func: Callable):
1225
1363
  """Decorator function providing SharedMemory Handler."""
1226
1364
 
@@ -1291,18 +1429,17 @@ def scan(
1291
1429
  Tuple
1292
1430
  The merged results from callback_class if provided otherwise None.
1293
1431
  """
1432
+ matching_data.to_backend()
1294
1433
  shape_diff = backend.subtract(
1295
- matching_data._target.shape, matching_data._template.shape
1434
+ matching_data._output_target_shape, matching_data._output_template_shape
1296
1435
  )
1436
+ shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
1297
1437
  if backend.sum(shape_diff < 0) and not pad_fourier:
1298
1438
  warnings.warn(
1299
1439
  "Target is larger than template and Fourier padding is turned off."
1300
1440
  " This can lead to shifted results. You can swap template and target, or"
1301
1441
  " zero-pad the target."
1302
1442
  )
1303
-
1304
- matching_data.to_backend()
1305
-
1306
1443
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1307
1444
  pad_fourier=pad_fourier
1308
1445
  )
@@ -1315,6 +1452,15 @@ def scan(
1315
1452
  complex_dtype=matching_data._complex_dtype,
1316
1453
  fftargs=fftargs,
1317
1454
  )
1455
+
1456
+ # template_filter = _setup_template_filter(
1457
+ # matching_data=matching_data,
1458
+ # rfftn=rfftn,
1459
+ # irfftn=irfftn,
1460
+ # fast_shape=fast_shape,
1461
+ # fast_ft_shape=fast_ft_shape,
1462
+ # )
1463
+
1318
1464
  setup = matching_setup(
1319
1465
  rfftn=rfftn,
1320
1466
  irfftn=irfftn,
@@ -1332,6 +1478,7 @@ def scan(
1332
1478
  )
1333
1479
  rfftn, irfftn = None, None
1334
1480
 
1481
+
1335
1482
  template_filter, preprocessor = None, Preprocessor()
1336
1483
  for method, parameters in matching_data.template_filter.items():
1337
1484
  parameters["shape"] = fast_shape
@@ -1345,9 +1492,8 @@ def scan(
1345
1492
  template_filter = backend.full(
1346
1493
  shape=(1,), fill_value=1, dtype=backend._default_dtype
1347
1494
  )
1348
- else:
1349
- template_filter = backend.to_backend_array(template_filter)
1350
1495
 
1496
+ template_filter = backend.to_backend_array(template_filter)
1351
1497
  template_filter = backend.astype(template_filter, backend._default_dtype)
1352
1498
  template_filter_buffer = backend.arr_to_sharedarr(
1353
1499
  arr=template_filter,
@@ -1369,14 +1515,21 @@ def scan(
1369
1515
  callback_class = setup.pop("callback_class", callback_class)
1370
1516
  callback_class_args = setup.pop("callback_class_args", callback_class_args)
1371
1517
  callback_classes = [callback_class for _ in range(n_callback_classes)]
1518
+
1519
+ convolution_mode = "same"
1520
+ if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1521
+ convolution_mode = "valid"
1522
+
1523
+
1524
+ callback_class_args["fourier_shift"] = fourier_shift
1525
+ callback_class_args["convolution_mode"] = convolution_mode
1526
+ callback_class_args["targetshape"] = setup["targetshape"]
1527
+ callback_class_args["templateshape"] = setup["templateshape"]
1528
+
1372
1529
  if callback_class == MaxScoreOverRotations:
1373
- score_space_shape = backend.subtract(
1374
- matching_data.target.shape,
1375
- matching_data._target_pad,
1376
- )
1377
1530
  callback_classes = [
1378
1531
  class_name(
1379
- score_space_shape=score_space_shape,
1532
+ score_space_shape=fast_shape,
1380
1533
  score_space_dtype=matching_data._default_dtype,
1381
1534
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
1382
1535
  rotation_space_dtype=backend._default_dtype_int,
@@ -1416,10 +1569,16 @@ def scan(
1416
1569
  for index, rotation in enumerate(rotation_list)
1417
1570
  )
1418
1571
 
1572
+ callbacks = callbacks[0:n_callback_classes]
1419
1573
  callbacks = [
1420
- tuple(callback)
1421
- for callback in callbacks[0:n_callback_classes]
1422
- if callback is not None
1574
+ tuple(callback._postprocess(
1575
+ fourier_shift = fourier_shift,
1576
+ convolution_mode = convolution_mode,
1577
+ targetshape = setup["targetshape"],
1578
+ templateshape = setup["templateshape"],
1579
+ shared_memory_handler=kwargs.get("shared_memory_handler", None)
1580
+ )) if hasattr(callback, "_postprocess") else tuple(callback)
1581
+ for callback in callbacks if callback is not None
1423
1582
  ]
1424
1583
  backend.free_cache()
1425
1584
 
@@ -1530,11 +1689,13 @@ def scan_subsets(
1530
1689
  matching_data._target, matching_data._template = None, None
1531
1690
  matching_data._target_mask, matching_data._template_mask = None, None
1532
1691
 
1692
+ candidates = None
1533
1693
  if callback_class is not None:
1534
1694
  candidates = callback_class.merge(
1535
1695
  results, **callback_class_args, inner_merge=False
1536
1696
  )
1537
- return candidates
1697
+
1698
+ return candidates
1538
1699
 
1539
1700
 
1540
1701
  MATCHING_EXHAUSTIVE_REGISTER = {
@@ -1544,6 +1705,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
1544
1705
  "CAM": (cam_setup, corr_scoring),
1545
1706
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1546
1707
  "FLC": (flc_setup, flc_scoring),
1708
+ "FLC2": (flc_setup, flc_scoring2),
1547
1709
  "MCC": (mcc_setup, mcc_scoring),
1548
1710
  }
1549
1711