pytme 0.1.8__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. pytme-0.2.0.data/scripts/match_template.py +1019 -0
  2. pytme-0.2.0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0.dist-info/RECORD +72 -0
  6. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +459 -218
  9. pytme-0.1.8.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +533 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +35 -6
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +173 -78
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +78 -32
  23. tme/matching_exhaustive.py +369 -221
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +753 -649
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +176 -0
  30. tme/preprocessing/composable_filter.py +30 -0
  31. tme/preprocessing/compose.py +52 -0
  32. tme/preprocessing/frequency_filters.py +322 -0
  33. tme/preprocessing/tilt_series.py +967 -0
  34. tme/preprocessor.py +35 -25
  35. tme/structure.py +2 -37
  36. pytme-0.1.8.data/scripts/postprocess.py +0 -625
  37. pytme-0.1.8.dist-info/RECORD +0 -61
  38. {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
  39. {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
  40. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
  41. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
@@ -19,13 +19,12 @@ from scipy.ndimage import laplace
19
19
 
20
20
  from .analyzer import MaxScoreOverRotations
21
21
  from .matching_utils import (
22
- apply_convolution_mode,
23
22
  handle_traceback,
24
23
  split_numpy_array_slices,
25
24
  conditional_execute,
26
25
  )
27
26
  from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
28
- from .preprocessor import Preprocessor
27
+ from .preprocessing import Compose
29
28
  from .matching_data import MatchingData
30
29
  from .backends import backend
31
30
  from .types import NDArray, CallbackClass
@@ -43,6 +42,58 @@ def _run_inner(backend_name, backend_args, **kwargs):
43
42
  return scan(**kwargs)
44
43
 
45
44
 
45
+ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
46
+ """
47
+ Standardizes the values in in template by subtracting the mean and dividing by the
48
+ standard deviation based on the elements in mask. Subsequently, the template is
49
+ multiplied by the mask.
50
+
51
+ Parameters
52
+ ----------
53
+ template : NDArray
54
+ The data array to be normalized. This array is modified in-place.
55
+ mask : NDArray
56
+ A boolean array of the same shape as `template`. True values indicate the positions in `template`
57
+ to consider for normalization.
58
+ mask_intensity : float
59
+ Mask intensity used to compute expectations.
60
+
61
+ References
62
+ ----------
63
+ .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
64
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
65
+ .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
66
+ R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
67
+ 13375 (2023)
68
+
69
+ Returns
70
+ -------
71
+ None
72
+ This function modifies `template` in-place and does not return any value.
73
+ """
74
+ masked_mean = backend.sum(backend.multiply(template, mask))
75
+ masked_mean = backend.divide(masked_mean, mask_intensity)
76
+ masked_std = backend.sum(backend.multiply(backend.square(template), mask))
77
+ masked_std = backend.subtract(
78
+ masked_std / mask_intensity, backend.square(masked_mean)
79
+ )
80
+ masked_std = backend.sqrt(backend.maximum(masked_std, 0))
81
+
82
+ backend.subtract(template, masked_mean, out=template)
83
+ backend.divide(template, masked_std, out=template)
84
+ backend.multiply(template, mask, out=template)
85
+ return None
86
+
87
+
88
+ def apply_filter(ft_template, template_filter):
89
+ # This is an approximation to applying the mask, irfftn, normalize, rfftn
90
+ std_before = backend.std(ft_template)
91
+ backend.multiply(ft_template, template_filter, out=ft_template)
92
+ backend.multiply(
93
+ ft_template, std_before / backend.std(ft_template), out=ft_template
94
+ )
95
+
96
+
46
97
  def cc_setup(
47
98
  rfftn: Callable,
48
99
  irfftn: Callable,
@@ -208,11 +259,16 @@ def corr_setup(
208
259
  target_pad, ft_target2, ft_window_template = None, None, None
209
260
 
210
261
  # Normalizing constants
211
- template_mean = backend.mean(template)
212
- template_volume = np.prod(template.shape)
262
+ n_observations = backend.sum(template_mask)
263
+ template_mean = backend.sum(backend.multiply(template, template_mask))
264
+ template_mean = backend.divide(template_mean, n_observations)
213
265
  template_ssd = backend.sum(
214
- backend.square(backend.subtract(template, template_mean))
266
+ backend.square(
267
+ backend.multiply(backend.multiply(template, template_mean), template_mask)
268
+ )
215
269
  )
270
+ template_volume = np.prod(template.shape)
271
+ backend.multiply(template, template_mask, out=template)
216
272
 
217
273
  # Final numerator is score - numerator2
218
274
  numerator2 = backend.multiply(target_window_sum, template_mean)
@@ -308,49 +364,6 @@ def cam_setup(**kwargs):
308
364
  return corr_setup(**kwargs)
309
365
 
310
366
 
311
- def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
312
- """
313
- Standardizes the values in in template by subtracting the mean and dividing by the
314
- standard deviation based on the elements in mask. Subsequently, the template is
315
- multiplied by the mask.
316
-
317
- Parameters
318
- ----------
319
- template : NDArray
320
- The data array to be normalized. This array is modified in-place.
321
- mask : NDArray
322
- A boolean array of the same shape as `template`. True values indicate the positions in `template`
323
- to consider for normalization.
324
- mask_intensity : float
325
- Mask intensity used to compute expectations.
326
-
327
- References
328
- ----------
329
- .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
330
- and F. Förster, J. Struct. Biol. 178, 177 (2012).
331
- .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
332
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
333
- 13375 (2023)
334
-
335
- Returns
336
- -------
337
- None
338
- This function modifies `template` in-place and does not return any value.
339
- """
340
- masked_mean = backend.sum(backend.multiply(template, mask))
341
- masked_mean = backend.divide(masked_mean, mask_intensity)
342
- masked_std = backend.sum(backend.multiply(backend.square(template), mask))
343
- masked_std = backend.subtract(
344
- masked_std / mask_intensity, backend.square(masked_mean)
345
- )
346
- masked_std = backend.sqrt(backend.maximum(masked_std, 0))
347
-
348
- backend.subtract(template, masked_mean, out=template)
349
- backend.divide(template, masked_std, out=template)
350
- backend.multiply(template, mask, out=template)
351
- return None
352
-
353
-
354
367
  def flc_setup(
355
368
  rfftn: Callable,
356
369
  irfftn: Callable,
@@ -413,7 +426,7 @@ def flc_setup(
413
426
  arr=ft_target2, shared_memory_handler=shared_memory_handler
414
427
  )
415
428
 
416
- _normalize_under_mask(
429
+ normalize_under_mask(
417
430
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
418
431
  )
419
432
 
@@ -535,13 +548,16 @@ def flcSphericalMask_setup(
535
548
  backend.fill(temp2, 0)
536
549
  temp2[nonzero_indices] = 1 / temp[nonzero_indices]
537
550
 
538
- _normalize_under_mask(
551
+ normalize_under_mask(
539
552
  template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
540
553
  )
541
554
 
542
555
  template_buffer = backend.arr_to_sharedarr(
543
556
  arr=template, shared_memory_handler=shared_memory_handler
544
557
  )
558
+ template_mask_buffer = backend.arr_to_sharedarr(
559
+ arr=template_mask, shared_memory_handler=shared_memory_handler
560
+ )
545
561
  target_ft_buffer = backend.arr_to_sharedarr(
546
562
  arr=ft_target, shared_memory_handler=shared_memory_handler
547
563
  )
@@ -553,6 +569,7 @@ def flcSphericalMask_setup(
553
569
  )
554
570
 
555
571
  template_tuple = (template_buffer, template.shape, real_dtype)
572
+ template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
556
573
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
557
574
 
558
575
  inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
@@ -560,6 +577,7 @@ def flcSphericalMask_setup(
560
577
 
561
578
  ret = {
562
579
  "template": template_tuple,
580
+ "template_mask": template_mask_tuple,
563
581
  "ft_target": target_ft_tuple,
564
582
  "inv_denominator": inv_denominator_tuple,
565
583
  "numerator2": numerator2_tuple,
@@ -717,10 +735,6 @@ def corr_scoring(
717
735
  datatype.
718
736
  numerator2 : Tuple[type, Tuple[int], type]
719
737
  Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
720
- targetshape : Tuple[int]
721
- The shape of the target.
722
- templateshape : Tuple[int]
723
- The shape of the template.
724
738
  fast_shape : Tuple[int]
725
739
  The shape for fast Fourier transform.
726
740
  fast_ft_shape : Tuple[int]
@@ -739,8 +753,6 @@ def corr_scoring(
739
753
  instantiable.
740
754
  interpolation_order : int
741
755
  The order of interpolation to be used while rotating the template.
742
- convolution_mode : str, optional
743
- Mode to use for convolution, default is "full".
744
756
  **kwargs :
745
757
  Additional arguments to be passed to the function.
746
758
 
@@ -756,31 +768,23 @@ def corr_scoring(
756
768
  :py:meth:`cam_setup`
757
769
  :py:meth:`flcSphericalMask_setup`
758
770
  """
759
- template_buffer, template_shape, template_dtype = template
760
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
761
- inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
762
- numerator2_buffer, numerator2_shape, _ = numerator2
763
- filter_buffer, filter_shape, filter_dtype = template_filter
764
-
771
+ callback = callback_class
765
772
  if callback_class is not None and isinstance(callback_class, type):
766
773
  callback = callback_class(**callback_class_args)
767
- elif not isinstance(callback_class, type):
768
- callback = callback_class
769
774
 
770
- # Retrieve objects from shared memory
771
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
772
- ft_target = backend.sharedarr_to_arr(
773
- ft_target_shape, ft_target_dtype, ft_target_buffer
774
- )
775
- inv_denominator = backend.sharedarr_to_arr(
776
- inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
777
- )
778
- numerator2 = backend.sharedarr_to_arr(
779
- numerator2_shape, template_dtype, numerator2_buffer
780
- )
781
- template_filter = backend.sharedarr_to_arr(
782
- filter_shape, filter_dtype, filter_buffer
783
- )
775
+ template_buffer, template_shape, template_dtype = template
776
+ template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
777
+ ft_target = backend.sharedarr_to_arr(*ft_target)
778
+ inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
779
+ numerator2 = backend.sharedarr_to_arr(*numerator2)
780
+ template_filter = backend.sharedarr_to_arr(*template_filter)
781
+
782
+ norm_template, template_mask, mask_sum = False, 1, 1
783
+ if "template_mask" in kwargs:
784
+ template_mask = backend.sharedarr_to_arr(
785
+ kwargs["template_mask"][0], template_shape, template_dtype
786
+ )
787
+ norm_template, mask_sum = True, backend.sum(template_mask)
784
788
 
785
789
  arr = backend.preallocate_array(fast_shape, real_dtype)
786
790
  ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
@@ -799,17 +803,16 @@ def corr_scoring(
799
803
  norm_denominator = (backend.sum(inv_denominator) != 1) & (
800
804
  backend.size(inv_denominator) != 1
801
805
  )
802
- filter_template = backend.size(template_filter) != 0
803
806
 
807
+ norm_template = conditional_execute(normalize_under_mask, norm_template)
808
+ callback_func = conditional_execute(callback, callback_class is not None)
804
809
  norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
805
810
  norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
806
- template_filter_func = conditional_execute(backend.multiply, filter_template)
807
-
808
- axis = tuple(range(arr.ndim))
809
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
810
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
811
+ template_filter_func = conditional_execute(
812
+ apply_filter, backend.size(template_filter) != 1
813
+ )
811
814
 
812
- template_sum = backend.sum(template)
815
+ unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
813
816
  for index in range(rotations.shape[0]):
814
817
  rotation = rotations[index]
815
818
  backend.fill(arr, 0)
@@ -820,33 +823,24 @@ def corr_scoring(
820
823
  use_geometric_center=False,
821
824
  order=interpolation_order,
822
825
  )
823
- rotation_norm = template_sum / backend.sum(arr)
824
- backend.multiply(arr, rotation_norm, out=arr)
825
826
 
826
- rfftn(arr, ft_temp)
827
- template_filter_func(ft_temp, template_filter, out=ft_temp)
827
+ norm_template(arr[unpadded_slice], template_mask, mask_sum)
828
828
 
829
+ rfftn(arr, ft_temp)
830
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
829
831
  backend.multiply(ft_target, ft_temp, out=ft_temp)
830
832
  irfftn(ft_temp, arr)
831
833
 
832
834
  norm_func_numerator(arr, numerator2, out=arr)
833
835
  norm_func_denominator(arr, inv_denominator, out=arr)
834
836
 
835
- if fourier_shift_scores:
836
- arr = backend.roll(arr, shift=fourier_shift, axis=axis)
837
-
838
- score = apply_convolution_mode(
839
- arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
837
+ callback_func(
838
+ arr,
839
+ rotation_matrix=rotation,
840
+ rotation_index=index,
841
+ **callback_class_args,
840
842
  )
841
843
 
842
- if callback_class is not None:
843
- callback(
844
- score,
845
- rotation_matrix=rotation,
846
- rotation_index=index,
847
- **callback_class_args,
848
- )
849
-
850
844
  return callback
851
845
 
852
846
 
@@ -894,32 +888,15 @@ def flc_scoring(
894
888
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
895
889
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
896
890
  """
897
- template_buffer, template_shape, template_dtype = template
898
- template_mask_buffer, *_ = template_mask
899
- filter_buffer, filter_shape, filter_dtype = template_filter
900
-
901
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
902
- ft_target2_buffer, *_ = ft_target2
903
-
891
+ callback = callback_class
904
892
  if callback_class is not None and isinstance(callback_class, type):
905
893
  callback = callback_class(**callback_class_args)
906
- elif not isinstance(callback_class, type):
907
- callback = callback_class
908
894
 
909
- # Retrieve objects from shared memory
910
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
911
- template_mask = backend.sharedarr_to_arr(
912
- template_shape, template_dtype, template_mask_buffer
913
- )
914
- ft_target = backend.sharedarr_to_arr(
915
- ft_target_shape, ft_target_dtype, ft_target_buffer
916
- )
917
- ft_target2 = backend.sharedarr_to_arr(
918
- ft_target_shape, ft_target_dtype, ft_target2_buffer
919
- )
920
- template_filter = backend.sharedarr_to_arr(
921
- filter_shape, filter_dtype, filter_buffer
922
- )
895
+ template = backend.sharedarr_to_arr(*template)
896
+ template_mask = backend.sharedarr_to_arr(*template_mask)
897
+ ft_target = backend.sharedarr_to_arr(*ft_target)
898
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
899
+ template_filter = backend.sharedarr_to_arr(*template_filter)
923
900
 
924
901
  arr = backend.preallocate_array(fast_shape, real_dtype)
925
902
  temp = backend.preallocate_array(fast_shape, real_dtype)
@@ -938,12 +915,10 @@ def flc_scoring(
938
915
  temp_fft=ft_temp,
939
916
  )
940
917
  eps = backend.eps(real_dtype)
941
- filter_template = backend.size(template_filter) != 0
942
- template_filter_func = conditional_execute(backend.multiply, filter_template)
943
-
944
- axis = tuple(range(arr.ndim))
945
- fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
946
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
918
+ template_filter_func = conditional_execute(
919
+ apply_filter, backend.size(template_filter) != 1
920
+ )
921
+ callback_func = conditional_execute(callback, callback_class is not None)
947
922
 
948
923
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
949
924
  for index in range(rotations.shape[0]):
@@ -962,7 +937,7 @@ def flc_scoring(
962
937
  # Given the amount of FFTs, might aswell normalize properly
963
938
  n_observations = backend.sum(temp)
964
939
 
965
- _normalize_under_mask(
940
+ normalize_under_mask(
966
941
  template=arr[unpadded_slice],
967
942
  mask=temp[unpadded_slice],
968
943
  mask_intensity=n_observations,
@@ -985,7 +960,7 @@ def flc_scoring(
985
960
  backend.multiply(temp, n_observations, out=temp)
986
961
 
987
962
  rfftn(arr, ft_temp)
988
- template_filter_func(ft_temp, template_filter, out=ft_temp)
963
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
989
964
  backend.multiply(ft_target, ft_temp, out=ft_temp)
990
965
  irfftn(ft_temp, arr)
991
966
 
@@ -994,23 +969,161 @@ def flc_scoring(
994
969
  backend.fill(temp2, 0)
995
970
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
996
971
 
997
- convolution_mode = kwargs.get("convolution_mode", "full")
972
+ callback_func(
973
+ temp2,
974
+ rotation_matrix=rotation,
975
+ rotation_index=index,
976
+ **callback_class_args,
977
+ )
978
+
979
+ return callback
998
980
 
999
- if fourier_shift_scores:
1000
- temp2 = backend.roll(temp2, shift=fourier_shift, axis=axis)
1001
981
 
1002
- score = apply_convolution_mode(
1003
- temp2, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
982
+ def flc_scoring2(
983
+ template: Tuple[type, Tuple[int], type],
984
+ template_mask: Tuple[type, Tuple[int], type],
985
+ ft_target: Tuple[type, Tuple[int], type],
986
+ ft_target2: Tuple[type, Tuple[int], type],
987
+ template_filter: Tuple[type, Tuple[int], type],
988
+ targetshape: Tuple[int],
989
+ templateshape: Tuple[int],
990
+ fast_shape: Tuple[int],
991
+ fast_ft_shape: Tuple[int],
992
+ rotations: NDArray,
993
+ real_dtype: type,
994
+ complex_dtype: type,
995
+ callback_class: CallbackClass,
996
+ callback_class_args: Dict,
997
+ interpolation_order: int,
998
+ **kwargs,
999
+ ) -> CallbackClass:
1000
+ """
1001
+ Computes a normalized cross-correlation score of a target f a template g
1002
+ and a mask m:
1003
+
1004
+ .. math::
1005
+
1006
+ \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
1007
+ {N_m * \\sqrt{
1008
+ \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
1009
+ }
1010
+
1011
+ Where:
1012
+
1013
+ .. math::
1014
+
1015
+ CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
1016
+
1017
+ and Nm is the number of voxels within the template mask m.
1018
+
1019
+ References
1020
+ ----------
1021
+ .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
1022
+ Microsc. Microanal. 26, 2516 (2020)
1023
+ .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1024
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
1025
+ """
1026
+ callback = callback_class
1027
+ if callback_class is not None and isinstance(callback_class, type):
1028
+ callback = callback_class(**callback_class_args)
1029
+
1030
+ # Retrieve objects from shared memory
1031
+ template = backend.sharedarr_to_arr(*template)
1032
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1033
+ ft_target = backend.sharedarr_to_arr(*ft_target)
1034
+ ft_target2 = backend.sharedarr_to_arr(*ft_target2)
1035
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1036
+
1037
+ arr = backend.preallocate_array(fast_shape, real_dtype)
1038
+ temp = backend.preallocate_array(fast_shape, real_dtype)
1039
+ temp2 = backend.preallocate_array(fast_shape, real_dtype)
1040
+
1041
+ ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
1042
+ ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)
1043
+
1044
+ eps = backend.eps(real_dtype)
1045
+ template_filter_func = conditional_execute(
1046
+ apply_filter, backend.size(template_filter) != 1
1047
+ )
1048
+ callback_func = conditional_execute(callback, callback_class is not None)
1049
+
1050
+ squeeze_axis = tuple(i for i, x in enumerate(template.shape) if x == 1)
1051
+ squeeze = tuple(
1052
+ slice(0, stop) if i not in squeeze_axis else 0
1053
+ for i, stop in enumerate(template.shape)
1054
+ )
1055
+ squeeze_fast = tuple(
1056
+ slice(0, stop) if i not in squeeze_axis else 0
1057
+ for i, stop in enumerate(fast_shape)
1058
+ )
1059
+ squeeze_fast_ft = tuple(
1060
+ slice(0, stop) if i not in squeeze_axis else 0
1061
+ for i, stop in enumerate(fast_ft_shape)
1062
+ )
1063
+
1064
+ rfftn, irfftn = backend.build_fft(
1065
+ fast_shape=temp[squeeze_fast].shape,
1066
+ fast_ft_shape=fast_ft_shape,
1067
+ real_dtype=real_dtype,
1068
+ complex_dtype=complex_dtype,
1069
+ fftargs=kwargs.get("fftargs", {}),
1070
+ inverse_fast_shape=fast_shape,
1071
+ temp_real=arr[squeeze_fast],
1072
+ temp_fft=ft_temp,
1073
+ )
1074
+ for index in range(rotations.shape[0]):
1075
+ rotation = rotations[index]
1076
+ backend.fill(arr, 0)
1077
+ backend.fill(temp, 0)
1078
+ backend.rotate_array(
1079
+ arr=template[squeeze],
1080
+ arr_mask=template_mask[squeeze],
1081
+ rotation_matrix=rotation,
1082
+ out=arr[squeeze],
1083
+ out_mask=temp[squeeze],
1084
+ use_geometric_center=False,
1085
+ order=interpolation_order,
1086
+ )
1087
+ # Given the amount of FFTs, might aswell normalize properly
1088
+ n_observations = backend.sum(temp)
1089
+
1090
+ normalize_under_mask(
1091
+ template=arr[squeeze],
1092
+ mask=temp[squeeze],
1093
+ mask_intensity=n_observations,
1004
1094
  )
1095
+ rfftn(temp[squeeze_fast], ft_temp[squeeze_fast_ft])
1005
1096
 
1006
- if callback_class is not None:
1007
- callback(
1008
- score,
1009
- rotation_matrix=rotation,
1010
- rotation_index=index,
1011
- **callback_class_args,
1012
- )
1097
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1098
+ irfftn(ft_denom, temp)
1099
+ backend.divide(temp, n_observations, out=temp)
1100
+ backend.square(temp, out=temp)
1013
1101
 
1102
+ backend.multiply(ft_target2, ft_temp[squeeze_fast_ft], out=ft_denom)
1103
+ irfftn(ft_denom, temp2)
1104
+ backend.divide(temp2, n_observations, out=temp2)
1105
+
1106
+ backend.subtract(temp2, temp, out=temp)
1107
+ backend.maximum(temp, 0.0, out=temp)
1108
+ backend.sqrt(temp, out=temp)
1109
+ backend.multiply(temp, n_observations, out=temp)
1110
+
1111
+ rfftn(arr[squeeze_fast], ft_temp[squeeze_fast_ft])
1112
+ template_filter_func(ft_template=ft_temp, template_filter=template_filter)
1113
+ backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1114
+ irfftn(ft_denom, arr)
1115
+
1116
+ tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
1117
+ nonzero_indices = temp > tol
1118
+ backend.fill(temp2, 0)
1119
+ temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1120
+
1121
+ callback_func(
1122
+ temp2,
1123
+ rotation_matrix=rotation,
1124
+ rotation_index=index,
1125
+ **callback_class_args,
1126
+ )
1014
1127
  return callback
1015
1128
 
1016
1129
 
@@ -1064,35 +1177,18 @@ def mcc_scoring(
1064
1177
  --------
1065
1178
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
1066
1179
  """
1067
- template_buffer, template_shape, template_dtype = template
1068
- ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
1069
- ft_target2_buffer, ft_target_shape, ft_target_dtype = ft_target2
1070
- template_mask_buffer, _, _ = template
1071
- ft_target_mask_buffer, _, _ = ft_target
1072
- filter_buffer, filter_shape, filter_dtype = template_filter
1073
-
1180
+ callback = callback_class
1074
1181
  if callback_class is not None and isinstance(callback_class, type):
1075
1182
  callback = callback_class(**callback_class_args)
1076
- elif not isinstance(callback_class, type):
1077
- callback = callback_class
1078
1183
 
1079
1184
  # Retrieve objects from shared memory
1080
- template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
1081
- target_ft = backend.sharedarr_to_arr(
1082
- ft_target_shape, ft_target_dtype, ft_target_buffer
1083
- )
1084
- target_ft2 = backend.sharedarr_to_arr(
1085
- ft_target_shape, ft_target_dtype, ft_target2_buffer
1086
- )
1087
- template_mask = backend.sharedarr_to_arr(
1088
- template_shape, template_dtype, template_mask_buffer
1089
- )
1090
- target_mask_ft = backend.sharedarr_to_arr(
1091
- ft_target_shape, ft_target_dtype, ft_target_mask_buffer
1092
- )
1093
- template_filter = backend.sharedarr_to_arr(
1094
- filter_shape, filter_dtype, filter_buffer
1095
- )
1185
+ template_buffer, template_shape, template_dtype = template
1186
+ template = backend.sharedarr_to_arr(*template)
1187
+ target_ft = backend.sharedarr_to_arr(*ft_target)
1188
+ target_ft2 = backend.sharedarr_to_arr(*ft_target2)
1189
+ template_mask = backend.sharedarr_to_arr(*template_mask)
1190
+ target_mask_ft = backend.sharedarr_to_arr(*ft_target_mask)
1191
+ template_filter = backend.sharedarr_to_arr(*template_filter)
1096
1192
 
1097
1193
  axes = tuple(range(template.ndim))
1098
1194
  eps = backend.eps(real_dtype)
@@ -1117,14 +1213,10 @@ def mcc_scoring(
1117
1213
  temp_fft=temp_ft,
1118
1214
  )
1119
1215
 
1120
- filter_template = backend.size(template_filter) != 0
1121
- template_filter_func = conditional_execute(backend.multiply, filter_template)
1122
-
1123
- axis = tuple(range(template.ndim))
1124
- fourier_shift = callback_class_args.get(
1125
- "fourier_shift", backend.zeros(template.ndim)
1216
+ template_filter_func = conditional_execute(
1217
+ apply_filter, backend.size(template_filter) != 1
1126
1218
  )
1127
- fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
1219
+ callback_func = conditional_execute(callback, callback_class is not None)
1128
1220
 
1129
1221
  # Calculate scores across all rotations
1130
1222
  for index in range(rotations.shape[0]):
@@ -1146,7 +1238,7 @@ def mcc_scoring(
1146
1238
 
1147
1239
  # template_rot_ft
1148
1240
  rfftn(template_rot, temp_ft)
1149
- template_filter_func(temp_ft, template_filter, out=temp_ft)
1241
+ template_filter_func(ft_template=temp_ft, template_filter=template_filter)
1150
1242
  irfftn(target_mask_ft * temp_ft, temp2)
1151
1243
  irfftn(target_ft * temp_ft, numerator)
1152
1244
 
@@ -1202,25 +1294,68 @@ def mcc_scoring(
1202
1294
  mask_overlap, axis=axes, keepdims=True
1203
1295
  )
1204
1296
  temp[mask_overlap < number_px_threshold] = 0.0
1205
- convolution_mode = kwargs.get("convolution_mode", "full")
1206
1297
 
1207
- if fourier_shift_scores:
1208
- temp = backend.roll(temp, shift=fourier_shift, axis=axis)
1209
-
1210
- score = apply_convolution_mode(
1211
- temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
1298
+ callback_func(
1299
+ temp,
1300
+ rotation_matrix=rotation,
1301
+ rotation_index=index,
1302
+ **callback_class_args,
1212
1303
  )
1213
- if callback_class is not None:
1214
- callback(
1215
- score,
1216
- rotation_matrix=rotation,
1217
- rotation_index=index,
1218
- **callback_class_args,
1219
- )
1220
1304
 
1221
1305
  return callback
1222
1306
 
1223
1307
 
1308
+ def _setup_template_filter_apply_target_filter(
1309
+ matching_data: MatchingData,
1310
+ rfftn: Callable,
1311
+ irfftn: Callable,
1312
+ fast_shape: Tuple[int],
1313
+ fast_ft_shape: Tuple[int],
1314
+ ):
1315
+ filter_template = isinstance(matching_data.template_filter, Compose)
1316
+ filter_target = isinstance(matching_data.target_filter, Compose)
1317
+
1318
+ template_filter = backend.full(
1319
+ shape=(1,), fill_value=1, dtype=backend._default_dtype
1320
+ )
1321
+
1322
+ if not filter_template and not filter_target:
1323
+ return template_filter
1324
+
1325
+ target_temp = backend.astype(
1326
+ backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
1327
+ )
1328
+ target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1329
+ rfftn(target_temp, target_temp_ft)
1330
+ if isinstance(matching_data.template_filter, Compose):
1331
+ template_filter = matching_data.template_filter(
1332
+ shape=fast_shape,
1333
+ return_real_fourier=True,
1334
+ shape_is_real_fourier=False,
1335
+ data_rfft=target_temp_ft,
1336
+ )
1337
+ template_filter = template_filter["data"]
1338
+ template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1339
+
1340
+ if isinstance(matching_data.target_filter, Compose):
1341
+ target_filter = matching_data.target_filter(
1342
+ shape=fast_shape,
1343
+ return_real_fourier=True,
1344
+ shape_is_real_fourier=False,
1345
+ data_rfft=target_temp_ft,
1346
+ weight_type=None,
1347
+ )
1348
+ target_filter = target_filter["data"]
1349
+ backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1350
+
1351
+ irfftn(target_temp_ft, target_temp)
1352
+ matching_data._target = backend.topleft_pad(
1353
+ target_temp, matching_data.target.shape
1354
+ )
1355
+
1356
+ return template_filter
1357
+
1358
+
1224
1359
  def device_memory_handler(func: Callable):
1225
1360
  """Decorator function providing SharedMemory Handler."""
1226
1361
 
@@ -1291,18 +1426,17 @@ def scan(
1291
1426
  Tuple
1292
1427
  The merged results from callback_class if provided otherwise None.
1293
1428
  """
1429
+ matching_data.to_backend()
1294
1430
  shape_diff = backend.subtract(
1295
- matching_data._target.shape, matching_data._template.shape
1431
+ matching_data._output_target_shape, matching_data._output_template_shape
1296
1432
  )
1433
+ shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
1297
1434
  if backend.sum(shape_diff < 0) and not pad_fourier:
1298
1435
  warnings.warn(
1299
1436
  "Target is larger than template and Fourier padding is turned off."
1300
- " This can lead to shifted results. You can swap template and target, or"
1301
- " zero-pad the target."
1437
+ " This can lead to shifted results. You can swap template and target, "
1438
+ " zero-pad the target or turn off template centering."
1302
1439
  )
1303
-
1304
- matching_data.to_backend()
1305
-
1306
1440
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1307
1441
  pad_fourier=pad_fourier
1308
1442
  )
@@ -1315,6 +1449,15 @@ def scan(
1315
1449
  complex_dtype=matching_data._complex_dtype,
1316
1450
  fftargs=fftargs,
1317
1451
  )
1452
+
1453
+ template_filter = _setup_template_filter_apply_target_filter(
1454
+ matching_data=matching_data,
1455
+ rfftn=rfftn,
1456
+ irfftn=irfftn,
1457
+ fast_shape=fast_shape,
1458
+ fast_ft_shape=fast_ft_shape,
1459
+ )
1460
+
1318
1461
  setup = matching_setup(
1319
1462
  rfftn=rfftn,
1320
1463
  irfftn=irfftn,
@@ -1332,22 +1475,7 @@ def scan(
1332
1475
  )
1333
1476
  rfftn, irfftn = None, None
1334
1477
 
1335
- template_filter, preprocessor = None, Preprocessor()
1336
- for method, parameters in matching_data.template_filter.items():
1337
- parameters["shape"] = fast_shape
1338
- parameters["omit_negative_frequencies"] = True
1339
- out = preprocessor.apply_method(method=method, parameters=parameters)
1340
- if template_filter is None:
1341
- template_filter = out
1342
- np.multiply(template_filter, out, out=template_filter)
1343
-
1344
- if template_filter is None:
1345
- template_filter = backend.full(
1346
- shape=(1,), fill_value=1, dtype=backend._default_dtype
1347
- )
1348
- else:
1349
- template_filter = backend.to_backend_array(template_filter)
1350
-
1478
+ template_filter = backend.to_backend_array(template_filter)
1351
1479
  template_filter = backend.astype(template_filter, backend._default_dtype)
1352
1480
  template_filter_buffer = backend.arr_to_sharedarr(
1353
1481
  arr=template_filter,
@@ -1369,14 +1497,20 @@ def scan(
1369
1497
  callback_class = setup.pop("callback_class", callback_class)
1370
1498
  callback_class_args = setup.pop("callback_class_args", callback_class_args)
1371
1499
  callback_classes = [callback_class for _ in range(n_callback_classes)]
1500
+
1501
+ convolution_mode = "same"
1502
+ if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1503
+ convolution_mode = "valid"
1504
+
1505
+ callback_class_args["fourier_shift"] = fourier_shift
1506
+ callback_class_args["convolution_mode"] = convolution_mode
1507
+ callback_class_args["targetshape"] = setup["targetshape"]
1508
+ callback_class_args["templateshape"] = setup["templateshape"]
1509
+
1372
1510
  if callback_class == MaxScoreOverRotations:
1373
- score_space_shape = backend.subtract(
1374
- matching_data.target.shape,
1375
- matching_data._target_pad,
1376
- )
1377
1511
  callback_classes = [
1378
1512
  class_name(
1379
- score_space_shape=score_space_shape,
1513
+ score_space_shape=fast_shape,
1380
1514
  score_space_dtype=matching_data._default_dtype,
1381
1515
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
1382
1516
  rotation_space_dtype=backend._default_dtype_int,
@@ -1416,9 +1550,20 @@ def scan(
1416
1550
  for index, rotation in enumerate(rotation_list)
1417
1551
  )
1418
1552
 
1553
+ callbacks = callbacks[0:n_callback_classes]
1419
1554
  callbacks = [
1420
- tuple(callback)
1421
- for callback in callbacks[0:n_callback_classes]
1555
+ tuple(
1556
+ callback._postprocess(
1557
+ fourier_shift=fourier_shift,
1558
+ convolution_mode=convolution_mode,
1559
+ targetshape=setup["targetshape"],
1560
+ templateshape=setup["templateshape"],
1561
+ shared_memory_handler=kwargs.get("shared_memory_handler", None),
1562
+ )
1563
+ )
1564
+ if hasattr(callback, "_postprocess")
1565
+ else tuple(callback)
1566
+ for callback in callbacks
1422
1567
  if callback is not None
1423
1568
  ]
1424
1569
  backend.free_cache()
@@ -1530,11 +1675,13 @@ def scan_subsets(
1530
1675
  matching_data._target, matching_data._template = None, None
1531
1676
  matching_data._target_mask, matching_data._template_mask = None, None
1532
1677
 
1678
+ candidates = None
1533
1679
  if callback_class is not None:
1534
1680
  candidates = callback_class.merge(
1535
1681
  results, **callback_class_args, inner_merge=False
1536
1682
  )
1537
- return candidates
1683
+
1684
+ return candidates
1538
1685
 
1539
1686
 
1540
1687
  MATCHING_EXHAUSTIVE_REGISTER = {
@@ -1544,6 +1691,7 @@ MATCHING_EXHAUSTIVE_REGISTER = {
1544
1691
  "CAM": (cam_setup, corr_scoring),
1545
1692
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1546
1693
  "FLC": (flc_setup, flc_scoring),
1694
+ "FLC2": (flc_setup, flc_scoring2),
1547
1695
  "MCC": (mcc_setup, mcc_scoring),
1548
1696
  }
1549
1697