pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
  2. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +183 -69
  8. scripts/match_template_filters.py +193 -71
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +259 -117
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +20 -8
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +79 -60
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +85 -61
  24. tme/matching_exhaustive.py +222 -129
  25. tme/matching_optimization.py +117 -76
  26. tme/orientations.py +175 -55
  27. tme/preprocessing/_utils.py +17 -5
  28. tme/preprocessing/composable_filter.py +2 -1
  29. tme/preprocessing/compose.py +1 -2
  30. tme/preprocessing/frequency_filters.py +97 -41
  31. tme/preprocessing/tilt_series.py +137 -87
  32. tme/preprocessor.py +3 -0
  33. tme/structure.py +4 -1
  34. pytme-0.2.0.dist-info/RECORD +0 -72
  35. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  36. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  37. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  38. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  39. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  40. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@
7
7
  import os
8
8
  import sys
9
9
  import warnings
10
- from copy import deepcopy
11
10
  from itertools import product
12
11
  from typing import Callable, Tuple, Dict
13
12
  from functools import wraps
@@ -62,9 +61,6 @@ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> No
62
61
  ----------
63
62
  .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
64
63
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
65
- .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
66
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
67
- 13375 (2023)
68
64
 
69
65
  Returns
70
66
  -------
@@ -85,6 +81,16 @@ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> No
85
81
  return None
86
82
 
87
83
 
84
+ def _normalize_under_mask_overflow_safe(
85
+ template: NDArray, mask: NDArray, mask_intensity
86
+ ) -> None:
87
+ _template = backend.astype(template, backend._overflow_safe_dtype)
88
+ _mask = backend.astype(mask, backend._overflow_safe_dtype)
89
+ normalize_under_mask(template=_template, mask=_mask, mask_intensity=mask_intensity)
90
+ template[:] = backend.astype(_template, template.dtype)
91
+ return None
92
+
93
+
88
94
  def apply_filter(ft_template, template_filter):
89
95
  # This is an approximation to applying the mask, irfftn, normalize, rfftn
90
96
  std_before = backend.std(ft_template)
@@ -101,8 +107,6 @@ def cc_setup(
101
107
  target: NDArray,
102
108
  fast_shape: Tuple[int],
103
109
  fast_ft_shape: Tuple[int],
104
- real_dtype: type,
105
- complex_dtype: type,
106
110
  shared_memory_handler: Callable,
107
111
  callback_class: Callable,
108
112
  callback_class_args: Dict,
@@ -122,7 +126,7 @@ def cc_setup(
122
126
  :py:meth:`corr_scoring`
123
127
  :py:class:`tme.matching_optimization.CrossCorrelation`
124
128
  """
125
- target_shape = target.shape
129
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
126
130
  target_pad = backend.topleft_pad(target, fast_shape)
127
131
  target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
128
132
 
@@ -139,28 +143,26 @@ def cc_setup(
139
143
  arr=backend.preallocate_array(1, real_dtype) + 1,
140
144
  shared_memory_handler=shared_memory_handler,
141
145
  )
142
- numerator2_buffer = backend.arr_to_sharedarr(
146
+ numerator_buffer = backend.arr_to_sharedarr(
143
147
  arr=backend.preallocate_array(1, real_dtype),
144
148
  shared_memory_handler=shared_memory_handler,
145
149
  )
146
150
 
147
151
  target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
148
- template_tuple = (template_out, template.shape, real_dtype)
152
+ template_tuple = (template_out, template.shape, template.dtype)
149
153
 
150
154
  inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
151
- numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
155
+ numerator_tuple = (numerator_buffer, (1,), real_dtype)
152
156
 
153
157
  ret = {
154
158
  "template": template_tuple,
155
159
  "ft_target": target_ft_tuple,
156
160
  "inv_denominator": inv_denominator_tuple,
157
- "numerator2": numerator2_tuple,
158
- "targetshape": target_shape,
161
+ "numerator": numerator_tuple,
162
+ "targetshape": target.shape,
159
163
  "templateshape": template.shape,
160
164
  "fast_shape": fast_shape,
161
165
  "fast_ft_shape": fast_ft_shape,
162
- "real_dtype": real_dtype,
163
- "complex_dtype": complex_dtype,
164
166
  "callback_class": callback_class,
165
167
  "callback_class_args": callback_class_args,
166
168
  "use_memmap": kwargs.get("use_memmap", False),
@@ -198,8 +200,6 @@ def corr_setup(
198
200
  target: NDArray,
199
201
  fast_shape: Tuple[int],
200
202
  fast_ft_shape: Tuple[int],
201
- real_dtype: type,
202
- complex_dtype: type,
203
203
  shared_memory_handler: Callable,
204
204
  callback_class: Callable,
205
205
  callback_class_args: Dict,
@@ -231,6 +231,7 @@ def corr_setup(
231
231
  :py:meth:`corr_scoring`
232
232
  :py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
233
233
  """
234
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
234
235
  target_pad = backend.topleft_pad(target, fast_shape)
235
236
 
236
237
  # The exact composition of the denominator is debatable
@@ -264,14 +265,14 @@ def corr_setup(
264
265
  template_mean = backend.divide(template_mean, n_observations)
265
266
  template_ssd = backend.sum(
266
267
  backend.square(
267
- backend.multiply(backend.multiply(template, template_mean), template_mask)
268
+ backend.multiply(backend.subtract(template, template_mean), template_mask)
268
269
  )
269
270
  )
270
- template_volume = np.prod(template.shape)
271
+ template_volume = np.prod(tuple(int(x) for x in template.shape))
271
272
  backend.multiply(template, template_mask, out=template)
272
273
 
273
- # Final numerator is score - numerator2
274
- numerator2 = backend.multiply(target_window_sum, template_mean)
274
+ # Final numerator is score - numerator
275
+ numerator = backend.multiply(target_window_sum, template_mean)
275
276
 
276
277
  # Compute denominator
277
278
  backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
@@ -298,29 +299,27 @@ def corr_setup(
298
299
  inv_denominator_buffer = backend.arr_to_sharedarr(
299
300
  arr=inv_denominator, shared_memory_handler=shared_memory_handler
300
301
  )
301
- numerator2_buffer = backend.arr_to_sharedarr(
302
- arr=numerator2, shared_memory_handler=shared_memory_handler
302
+ numerator_buffer = backend.arr_to_sharedarr(
303
+ arr=numerator, shared_memory_handler=shared_memory_handler
303
304
  )
304
305
 
305
- template_tuple = (template_buffer, deepcopy(template.shape), real_dtype)
306
+ template_tuple = (template_buffer, template.shape, template.dtype)
306
307
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
307
308
 
308
309
  inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
309
- numerator2_tuple = (numerator2_buffer, fast_shape, real_dtype)
310
+ numerator_tuple = (numerator_buffer, fast_shape, real_dtype)
310
311
 
311
- ft_target, inv_denominator, numerator2 = None, None, None
312
+ ft_target, inv_denominator, numerator = None, None, None
312
313
 
313
314
  ret = {
314
315
  "template": template_tuple,
315
316
  "ft_target": target_ft_tuple,
316
317
  "inv_denominator": inv_denominator_tuple,
317
- "numerator2": numerator2_tuple,
318
- "targetshape": deepcopy(target.shape),
319
- "templateshape": deepcopy(template.shape),
318
+ "numerator": numerator_tuple,
319
+ "targetshape": target.shape,
320
+ "templateshape": template.shape,
320
321
  "fast_shape": fast_shape,
321
322
  "fast_ft_shape": fast_ft_shape,
322
- "real_dtype": real_dtype,
323
- "complex_dtype": complex_dtype,
324
323
  "callback_class": callback_class,
325
324
  "callback_class_args": callback_class_args,
326
325
  "template_mean": kwargs.get("template_mean", template_mean),
@@ -372,8 +371,6 @@ def flc_setup(
372
371
  target: NDArray,
373
372
  fast_shape: Tuple[int],
374
373
  fast_ft_shape: Tuple[int],
375
- real_dtype: type,
376
- complex_dtype: type,
377
374
  shared_memory_handler: Callable,
378
375
  callback_class: Callable,
379
376
  callback_class_args: Dict,
@@ -412,8 +409,8 @@ def flc_setup(
412
409
  target_pad = backend.topleft_pad(target, fast_shape)
413
410
 
414
411
  # Target and squared target window sums
415
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
416
- ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
412
+ ft_target = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
413
+ ft_target2 = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
417
414
  rfftn(target_pad, ft_target)
418
415
  backend.square(target_pad, out=target_pad)
419
416
  rfftn(target_pad, ft_target2)
@@ -437,11 +434,15 @@ def flc_setup(
437
434
  arr=template_mask, shared_memory_handler=shared_memory_handler
438
435
  )
439
436
 
440
- template_tuple = (template_buffer, template.shape, real_dtype)
441
- template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)
437
+ template_tuple = (template_buffer, template.shape, template.dtype)
438
+ template_mask_tuple = (
439
+ template_mask_buffer,
440
+ template_mask.shape,
441
+ template_mask.dtype,
442
+ )
442
443
 
443
- target_ft_tuple = (ft_target, fast_ft_shape, complex_dtype)
444
- target_ft2_tuple = (ft_target2, fast_ft_shape, complex_dtype)
444
+ target_ft_tuple = (ft_target, fast_ft_shape, backend._complex_dtype)
445
+ target_ft2_tuple = (ft_target2, fast_ft_shape, backend._complex_dtype)
445
446
 
446
447
  ret = {
447
448
  "template": template_tuple,
@@ -452,8 +453,6 @@ def flc_setup(
452
453
  "templateshape": template.shape,
453
454
  "fast_shape": fast_shape,
454
455
  "fast_ft_shape": fast_ft_shape,
455
- "real_dtype": real_dtype,
456
- "complex_dtype": complex_dtype,
457
456
  "callback_class": callback_class,
458
457
  "callback_class_args": callback_class_args,
459
458
  }
@@ -469,8 +468,6 @@ def flcSphericalMask_setup(
469
468
  target: NDArray,
470
469
  fast_shape: Tuple[int],
471
470
  fast_ft_shape: Tuple[int],
472
- real_dtype: type,
473
- complex_dtype: type,
474
471
  shared_memory_handler: Callable,
475
472
  callback_class: Callable,
476
473
  callback_class_args: Dict,
@@ -507,6 +504,7 @@ def flcSphericalMask_setup(
507
504
  --------
508
505
  :py:meth:`corr_scoring`
509
506
  """
507
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
510
508
  target_pad = backend.topleft_pad(target, fast_shape)
511
509
 
512
510
  # Target and squared target window sums
@@ -516,10 +514,14 @@ def flcSphericalMask_setup(
516
514
 
517
515
  temp = backend.preallocate_array(fast_shape, real_dtype)
518
516
  temp2 = backend.preallocate_array(fast_shape, real_dtype)
519
- numerator2 = backend.preallocate_array(1, real_dtype)
517
+ numerator = backend.preallocate_array(1, real_dtype)
520
518
 
521
- eps = backend.eps(real_dtype)
522
- n_observations = backend.sum(template_mask)
519
+ n_observations, norm_func = backend.sum(template_mask), normalize_under_mask
520
+ if backend.datatype_bytes(template_mask.dtype) == 2:
521
+ norm_func = _normalize_under_mask_overflow_safe
522
+ n_observations = backend.sum(
523
+ backend.astype(template_mask, backend._overflow_safe_dtype)
524
+ )
523
525
 
524
526
  template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
525
527
  rfftn(template_mask_pad, ft_template_mask)
@@ -542,15 +544,11 @@ def flcSphericalMask_setup(
542
544
  backend.sqrt(temp, out=temp)
543
545
  backend.multiply(temp, n_observations, out=temp)
544
546
 
545
- tol = 1e3 * eps * backend.max(backend.abs(temp))
546
- nonzero_indices = temp > tol
547
-
548
547
  backend.fill(temp2, 0)
548
+ nonzero_indices = temp > backend.eps(real_dtype)
549
549
  temp2[nonzero_indices] = 1 / temp[nonzero_indices]
550
550
 
551
- normalize_under_mask(
552
- template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
553
- )
551
+ norm_func(template=template, mask=template_mask, mask_intensity=n_observations)
554
552
 
555
553
  template_buffer = backend.arr_to_sharedarr(
556
554
  arr=template, shared_memory_handler=shared_memory_handler
@@ -564,29 +562,27 @@ def flcSphericalMask_setup(
564
562
  inv_denominator_buffer = backend.arr_to_sharedarr(
565
563
  arr=temp2, shared_memory_handler=shared_memory_handler
566
564
  )
567
- numerator2_buffer = backend.arr_to_sharedarr(
568
- arr=numerator2, shared_memory_handler=shared_memory_handler
565
+ numerator_buffer = backend.arr_to_sharedarr(
566
+ arr=numerator, shared_memory_handler=shared_memory_handler
569
567
  )
570
568
 
571
- template_tuple = (template_buffer, template.shape, real_dtype)
572
- template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
569
+ template_tuple = (template_buffer, template.shape, template.dtype)
570
+ template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
573
571
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
574
572
 
575
573
  inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
576
- numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
574
+ numerator_tuple = (numerator_buffer, (1,), real_dtype)
577
575
 
578
576
  ret = {
579
577
  "template": template_tuple,
580
578
  "template_mask": template_mask_tuple,
581
579
  "ft_target": target_ft_tuple,
582
580
  "inv_denominator": inv_denominator_tuple,
583
- "numerator2": numerator2_tuple,
581
+ "numerator": numerator_tuple,
584
582
  "targetshape": target.shape,
585
583
  "templateshape": template.shape,
586
584
  "fast_shape": fast_shape,
587
585
  "fast_ft_shape": fast_ft_shape,
588
- "real_dtype": real_dtype,
589
- "complex_dtype": complex_dtype,
590
586
  "callback_class": callback_class,
591
587
  "callback_class_args": callback_class_args,
592
588
  }
@@ -603,8 +599,6 @@ def mcc_setup(
603
599
  target_mask: NDArray,
604
600
  fast_shape: Tuple[int],
605
601
  fast_ft_shape: Tuple[int],
606
- real_dtype: type,
607
- complex_dtype: type,
608
602
  shared_memory_handler: Callable,
609
603
  callback_class: Callable,
610
604
  callback_class_args: Dict,
@@ -641,6 +635,7 @@ def mcc_setup(
641
635
  :py:meth:`mcc_scoring`
642
636
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
643
637
  """
638
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
644
639
  target = backend.multiply(target, target_mask > 0, out=target)
645
640
 
646
641
  target_pad = backend.topleft_pad(target, fast_shape)
@@ -671,8 +666,8 @@ def mcc_setup(
671
666
  arr=template_mask, shared_memory_handler=shared_memory_handler
672
667
  )
673
668
 
674
- template_tuple = (template_buffer, template.shape, real_dtype)
675
- template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
669
+ template_tuple = (template_buffer, template.shape, template.dtype)
670
+ template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
676
671
 
677
672
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
678
673
  target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
@@ -688,8 +683,6 @@ def mcc_setup(
688
683
  "templateshape": template.shape,
689
684
  "fast_shape": fast_shape,
690
685
  "fast_ft_shape": fast_ft_shape,
691
- "real_dtype": real_dtype,
692
- "complex_dtype": complex_dtype,
693
686
  "callback_class": callback_class,
694
687
  "callback_class_args": callback_class_args,
695
688
  }
@@ -701,15 +694,13 @@ def corr_scoring(
701
694
  template: Tuple[type, Tuple[int], type],
702
695
  ft_target: Tuple[type, Tuple[int], type],
703
696
  inv_denominator: Tuple[type, Tuple[int], type],
704
- numerator2: Tuple[type, Tuple[int], type],
697
+ numerator: Tuple[type, Tuple[int], type],
705
698
  template_filter: Tuple[type, Tuple[int], type],
706
699
  targetshape: Tuple[int],
707
700
  templateshape: Tuple[int],
708
701
  fast_shape: Tuple[int],
709
702
  fast_ft_shape: Tuple[int],
710
703
  rotations: NDArray,
711
- real_dtype: type,
712
- complex_dtype: type,
713
704
  callback_class: CallbackClass,
714
705
  callback_class_args: Dict,
715
706
  interpolation_order: int,
@@ -721,7 +712,7 @@ def corr_scoring(
721
712
 
722
713
  .. math::
723
714
 
724
- (CC(f,g) - numerator2) \\cdot inv\\_denominator
715
+ (CC(f,g) - numerator) \\cdot inv\\_denominator
725
716
 
726
717
  Parameters
727
718
  ----------
@@ -733,8 +724,8 @@ def corr_scoring(
733
724
  inv_denominator : Tuple[type, Tuple[int], type]
734
725
  Tuple containing a pointer to the inverse denominator data, its shape, and its
735
726
  datatype.
736
- numerator2 : Tuple[type, Tuple[int], type]
737
- Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
727
+ numerator : Tuple[type, Tuple[int], type]
728
+ Tuple containing a pointer to the numerator data, its shape, and its datatype.
738
729
  fast_shape : Tuple[int]
739
730
  The shape for fast Fourier transform.
740
731
  fast_ft_shape : Tuple[int]
@@ -768,6 +759,8 @@ def corr_scoring(
768
759
  :py:meth:`cam_setup`
769
760
  :py:meth:`flcSphericalMask_setup`
770
761
  """
762
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
763
+
771
764
  callback = callback_class
772
765
  if callback_class is not None and isinstance(callback_class, type):
773
766
  callback = callback_class(**callback_class_args)
@@ -776,15 +769,33 @@ def corr_scoring(
776
769
  template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
777
770
  ft_target = backend.sharedarr_to_arr(*ft_target)
778
771
  inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
779
- numerator2 = backend.sharedarr_to_arr(*numerator2)
772
+ numerator = backend.sharedarr_to_arr(*numerator)
780
773
  template_filter = backend.sharedarr_to_arr(*template_filter)
781
774
 
775
+ norm_func = normalize_under_mask
782
776
  norm_template, template_mask, mask_sum = False, 1, 1
783
777
  if "template_mask" in kwargs:
784
- template_mask = backend.sharedarr_to_arr(
785
- kwargs["template_mask"][0], template_shape, template_dtype
786
- )
787
- norm_template, mask_sum = True, backend.sum(template_mask)
778
+ norm_template = True
779
+ template_mask = backend.sharedarr_to_arr(*kwargs["template_mask"])
780
+ mask_sum = backend.sum(template_mask)
781
+ if backend.datatype_bytes(template_mask.dtype) == 2:
782
+ norm_func = _normalize_under_mask_overflow_safe
783
+ mask_sum = backend.sum(
784
+ backend.astype(template_mask, backend._overflow_safe_dtype)
785
+ )
786
+
787
+ norm_template = conditional_execute(norm_func, norm_template)
788
+ norm_numerator = (backend.sum(numerator) != 0) & (backend.size(numerator) != 1)
789
+ norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
790
+
791
+ norm_denominator = (backend.sum(inv_denominator) != 1) & (
792
+ backend.size(inv_denominator) != 1
793
+ )
794
+ norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
795
+ callback_func = conditional_execute(callback, callback_class is not None)
796
+ template_filter_func = conditional_execute(
797
+ apply_filter, backend.size(template_filter) != 1
798
+ )
788
799
 
789
800
  arr = backend.preallocate_array(fast_shape, real_dtype)
790
801
  ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
@@ -799,19 +810,6 @@ def corr_scoring(
799
810
  temp_fft=ft_temp,
800
811
  )
801
812
 
802
- norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
803
- norm_denominator = (backend.sum(inv_denominator) != 1) & (
804
- backend.size(inv_denominator) != 1
805
- )
806
-
807
- norm_template = conditional_execute(normalize_under_mask, norm_template)
808
- callback_func = conditional_execute(callback, callback_class is not None)
809
- norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
810
- norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
811
- template_filter_func = conditional_execute(
812
- apply_filter, backend.size(template_filter) != 1
813
- )
814
-
815
813
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
816
814
  for index in range(rotations.shape[0]):
817
815
  rotation = rotations[index]
@@ -820,10 +818,9 @@ def corr_scoring(
820
818
  arr=template,
821
819
  rotation_matrix=rotation,
822
820
  out=arr,
823
- use_geometric_center=False,
821
+ use_geometric_center=True,
824
822
  order=interpolation_order,
825
823
  )
826
-
827
824
  norm_template(arr[unpadded_slice], template_mask, mask_sum)
828
825
 
829
826
  rfftn(arr, ft_temp)
@@ -831,7 +828,7 @@ def corr_scoring(
831
828
  backend.multiply(ft_target, ft_temp, out=ft_temp)
832
829
  irfftn(ft_temp, arr)
833
830
 
834
- norm_func_numerator(arr, numerator2, out=arr)
831
+ norm_func_numerator(arr, numerator, out=arr)
835
832
  norm_func_denominator(arr, inv_denominator, out=arr)
836
833
 
837
834
  callback_func(
@@ -855,8 +852,6 @@ def flc_scoring(
855
852
  fast_shape: Tuple[int],
856
853
  fast_ft_shape: Tuple[int],
857
854
  rotations: NDArray,
858
- real_dtype: type,
859
- complex_dtype: type,
860
855
  callback_class: CallbackClass,
861
856
  callback_class_args: Dict,
862
857
  interpolation_order: int,
@@ -888,6 +883,8 @@ def flc_scoring(
888
883
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
889
884
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
890
885
  """
886
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
887
+
891
888
  callback = callback_class
892
889
  if callback_class is not None and isinstance(callback_class, type):
893
890
  callback = callback_class(**callback_class_args)
@@ -931,7 +928,7 @@ def flc_scoring(
931
928
  rotation_matrix=rotation,
932
929
  out=arr,
933
930
  out_mask=temp,
934
- use_geometric_center=False,
931
+ use_geometric_center=True,
935
932
  order=interpolation_order,
936
933
  )
937
934
  # Given the amount of FFTs, might aswell normalize properly
@@ -964,7 +961,8 @@ def flc_scoring(
964
961
  backend.multiply(ft_target, ft_temp, out=ft_temp)
965
962
  irfftn(ft_temp, arr)
966
963
 
967
- tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
964
+ tol = eps
965
+ # tol = 1e3 * eps * backend.max(backend.abs(temp))
968
966
  nonzero_indices = temp > tol
969
967
  backend.fill(temp2, 0)
970
968
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
@@ -990,8 +988,6 @@ def flc_scoring2(
990
988
  fast_shape: Tuple[int],
991
989
  fast_ft_shape: Tuple[int],
992
990
  rotations: NDArray,
993
- real_dtype: type,
994
- complex_dtype: type,
995
991
  callback_class: CallbackClass,
996
992
  callback_class_args: Dict,
997
993
  interpolation_order: int,
@@ -1023,6 +1019,7 @@ def flc_scoring2(
1023
1019
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1024
1020
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
1025
1021
  """
1022
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1026
1023
  callback = callback_class
1027
1024
  if callback_class is not None and isinstance(callback_class, type):
1028
1025
  callback = callback_class(**callback_class_args)
@@ -1081,7 +1078,7 @@ def flc_scoring2(
1081
1078
  rotation_matrix=rotation,
1082
1079
  out=arr[squeeze],
1083
1080
  out_mask=temp[squeeze],
1084
- use_geometric_center=False,
1081
+ use_geometric_center=True,
1085
1082
  order=interpolation_order,
1086
1083
  )
1087
1084
  # Given the amount of FFTs, might aswell normalize properly
@@ -1113,8 +1110,7 @@ def flc_scoring2(
1113
1110
  backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1114
1111
  irfftn(ft_denom, arr)
1115
1112
 
1116
- tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
1117
- nonzero_indices = temp > tol
1113
+ nonzero_indices = temp > eps
1118
1114
  backend.fill(temp2, 0)
1119
1115
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1120
1116
 
@@ -1139,8 +1135,6 @@ def mcc_scoring(
1139
1135
  fast_shape: Tuple[int],
1140
1136
  fast_ft_shape: Tuple[int],
1141
1137
  rotations: NDArray,
1142
- real_dtype: type,
1143
- complex_dtype: type,
1144
1138
  callback_class: CallbackClass,
1145
1139
  callback_class_args: type,
1146
1140
  interpolation_order: int,
@@ -1177,6 +1171,7 @@ def mcc_scoring(
1177
1171
  --------
1178
1172
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
1179
1173
  """
1174
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1180
1175
  callback = callback_class
1181
1176
  if callback_class is not None and isinstance(callback_class, type):
1182
1177
  callback = callback_class(**callback_class_args)
@@ -1230,7 +1225,7 @@ def mcc_scoring(
1230
1225
  rotation_matrix=rotation,
1231
1226
  out=template_rot,
1232
1227
  out_mask=temp,
1233
- use_geometric_center=False,
1228
+ use_geometric_center=True,
1234
1229
  order=interpolation_order,
1235
1230
  )
1236
1231
 
@@ -1315,17 +1310,24 @@ def _setup_template_filter_apply_target_filter(
1315
1310
  filter_template = isinstance(matching_data.template_filter, Compose)
1316
1311
  filter_target = isinstance(matching_data.target_filter, Compose)
1317
1312
 
1318
- template_filter = backend.full(
1319
- shape=(1,), fill_value=1, dtype=backend._default_dtype
1320
- )
1313
+ template_filter = backend.full(shape=(1,), fill_value=1, dtype=backend._float_dtype)
1321
1314
 
1322
1315
  if not filter_template and not filter_target:
1323
1316
  return template_filter
1324
1317
 
1325
1318
  target_temp = backend.astype(
1326
- backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
1319
+ backend.topleft_pad(matching_data.target, fast_shape), backend._float_dtype
1327
1320
  )
1328
1321
  target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1322
+
1323
+ filter_shape = backend.multiply(fast_ft_shape, 1 - matching_data._batch_mask)
1324
+ filter_shape[filter_shape == 0] = 1
1325
+ fast_shape = backend.multiply(fast_shape, 1 - matching_data._batch_mask)
1326
+ fast_shape = fast_shape[fast_shape != 0]
1327
+
1328
+ fast_shape = tuple(int(x) for x in fast_shape)
1329
+ filter_shape = tuple(int(x) for x in filter_shape)
1330
+
1329
1331
  rfftn(target_temp, target_temp_ft)
1330
1332
  if isinstance(matching_data.template_filter, Compose):
1331
1333
  template_filter = matching_data.template_filter(
@@ -1333,9 +1335,11 @@ def _setup_template_filter_apply_target_filter(
1333
1335
  return_real_fourier=True,
1334
1336
  shape_is_real_fourier=False,
1335
1337
  data_rfft=target_temp_ft,
1338
+ batch_dimension=matching_data._target_dims,
1336
1339
  )
1337
1340
  template_filter = template_filter["data"]
1338
1341
  template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1342
+ template_filter = backend.reshape(template_filter, filter_shape)
1339
1343
 
1340
1344
  if isinstance(matching_data.target_filter, Compose):
1341
1345
  target_filter = matching_data.target_filter(
@@ -1344,8 +1348,10 @@ def _setup_template_filter_apply_target_filter(
1344
1348
  shape_is_real_fourier=False,
1345
1349
  data_rfft=target_temp_ft,
1346
1350
  weight_type=None,
1351
+ batch_dimension=matching_data._target_dims,
1347
1352
  )
1348
1353
  target_filter = target_filter["data"]
1354
+ target_filter = backend.reshape(target_filter, filter_shape)
1349
1355
  backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1350
1356
 
1351
1357
  irfftn(target_temp_ft, target_temp)
@@ -1392,8 +1398,7 @@ def scan(
1392
1398
  **kwargs,
1393
1399
  ) -> Tuple:
1394
1400
  """
1395
- Perform template matching between target and template and sample
1396
- different rotations of template.
1401
+ Perform template matching.
1397
1402
 
1398
1403
  Parameters
1399
1404
  ----------
@@ -1417,24 +1422,43 @@ def scan(
1417
1422
  Order of spline interpolation for rotations.
1418
1423
  jobs_per_callback_class : int, optional
1419
1424
  How many jobs should be processed by a single callback_class instance,
1420
- if ones is provided.
1425
+ if one is provided.
1421
1426
  **kwargs : various
1422
- Additional arguments.
1427
+ Additional keyword arguments.
1423
1428
 
1424
1429
  Returns
1425
1430
  -------
1426
1431
  Tuple
1427
1432
  The merged results from callback_class if provided otherwise None.
1433
+
1434
+ Examples
1435
+ --------
1436
+ Schematically, using :py:meth:`scan` is similar to :py:meth:`scan_subsets`,
1437
+ with the distinction that the objects contained in ``matching_data`` are not
1438
+ split and the search is only parallelized over angles.
1439
+ Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
1440
+ can be invoked like so
1441
+
1442
+ >>> from tme.matching_exhaustive import scan
1443
+ >>> results = scan(
1444
+ >>> matching_data = matching_data,
1445
+ >>> matching_score = matching_score,
1446
+ >>> matching_setup = matching_setup,
1447
+ >>> callback_class = callback_class,
1448
+ >>> callback_class_args = callback_class_args,
1449
+ >>> )
1450
+
1428
1451
  """
1429
1452
  matching_data.to_backend()
1430
1453
  shape_diff = backend.subtract(
1431
1454
  matching_data._output_target_shape, matching_data._output_template_shape
1432
1455
  )
1433
- shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
1456
+ shape_diff = backend.multiply(shape_diff, 1 - matching_data._batch_mask)
1457
+
1434
1458
  if backend.sum(shape_diff < 0) and not pad_fourier:
1435
1459
  warnings.warn(
1436
1460
  "Target is larger than template and Fourier padding is turned off."
1437
- " This can lead to shifted results. You can swap template and target, "
1461
+ " This can lead to shifted results. You can swap template and target,"
1438
1462
  " zero-pad the target or turn off template centering."
1439
1463
  )
1440
1464
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
@@ -1445,8 +1469,8 @@ def scan(
1445
1469
  rfftn, irfftn = backend.build_fft(
1446
1470
  fast_shape=fast_shape,
1447
1471
  fast_ft_shape=fast_ft_shape,
1448
- real_dtype=matching_data._default_dtype,
1449
- complex_dtype=matching_data._complex_dtype,
1472
+ real_dtype=backend._float_dtype,
1473
+ complex_dtype=backend._complex_dtype,
1450
1474
  fftargs=fftargs,
1451
1475
  )
1452
1476
 
@@ -1467,8 +1491,6 @@ def scan(
1467
1491
  target_mask=matching_data.target_mask,
1468
1492
  fast_shape=fast_shape,
1469
1493
  fast_ft_shape=fast_ft_shape,
1470
- real_dtype=matching_data._default_dtype,
1471
- complex_dtype=matching_data._complex_dtype,
1472
1494
  callback_class=callback_class,
1473
1495
  callback_class_args=callback_class_args,
1474
1496
  **kwargs,
@@ -1476,7 +1498,7 @@ def scan(
1476
1498
  rfftn, irfftn = None, None
1477
1499
 
1478
1500
  template_filter = backend.to_backend_array(template_filter)
1479
- template_filter = backend.astype(template_filter, backend._default_dtype)
1501
+ template_filter = backend.astype(template_filter, backend._float_dtype)
1480
1502
  template_filter_buffer = backend.arr_to_sharedarr(
1481
1503
  arr=template_filter,
1482
1504
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
@@ -1511,9 +1533,9 @@ def scan(
1511
1533
  callback_classes = [
1512
1534
  class_name(
1513
1535
  score_space_shape=fast_shape,
1514
- score_space_dtype=matching_data._default_dtype,
1536
+ score_space_dtype=backend._float_dtype,
1515
1537
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
1516
- rotation_space_dtype=backend._default_dtype_int,
1538
+ rotation_space_dtype=backend._int_dtype,
1517
1539
  **callback_class_args,
1518
1540
  )
1519
1541
  for class_name in callback_classes
@@ -1570,10 +1592,13 @@ def scan(
1570
1592
 
1571
1593
  merged_callback = None
1572
1594
  if callback_class is not None:
1595
+ score_indices = None
1596
+ if hasattr(matching_data, "indices"):
1597
+ score_indices = matching_data.indices
1573
1598
  merged_callback = callback_class.merge(
1574
1599
  callbacks,
1575
1600
  **callback_class_args,
1576
- score_indices=matching_data.indices,
1601
+ score_indices=score_indices,
1577
1602
  inner_merge=True,
1578
1603
  )
1579
1604
 
@@ -1596,14 +1621,14 @@ def scan_subsets(
1596
1621
  **kwargs,
1597
1622
  ) -> Tuple:
1598
1623
  """
1599
- Wrapper around :py:meth:`scan` that supports template matching on splits
1600
- of template and target.
1624
+ Wrapper around :py:meth:`scan` that supports matching on splits
1625
+ of ``matching_data``.
1601
1626
 
1602
1627
  Parameters
1603
1628
  ----------
1604
- matching_data : MatchingData
1605
- Template matching data.
1606
- matching_func : type
1629
+ matching_data : :py:class:`tme.matching_data.MatchingData`
1630
+ MatchingData instance containing relevant data.
1631
+ matching_setup : type
1607
1632
  Function pointer to setup function.
1608
1633
  matching_score : type
1609
1634
  Function pointer to scoring function.
@@ -1612,11 +1637,15 @@ def scan_subsets(
1612
1637
  callback_class_args : dict, optional
1613
1638
  Arguments passed to the callback_class. Default is an empty dictionary.
1614
1639
  job_schedule : tuple of int, optional
1615
- Schedule of jobs. Default is (1, 1).
1640
+ Job scheduling scheme, default is (1, 1). First value corresponds
1641
+ to the number of splits that are processed in parallel, the second
1642
+ to the number of angles evaluated in parallel on each split.
1616
1643
  target_splits : dict, optional
1617
- Splits for target. Default is an empty dictionary, i.e. no splits
1644
+ Splits for target. Default is an empty dictionary, i.e. no splits.
1645
+ See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1618
1646
  template_splits : dict, optional
1619
1647
  Splits for template. Default is an empty dictionary, i.e. no splits.
1648
+ See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1620
1649
  pad_target_edges : bool, optional
1621
1650
  Whether to pad the target boundaries by half the template shape
1622
1651
  along each axis.
@@ -1638,10 +1667,74 @@ def scan_subsets(
1638
1667
  -------
1639
1668
  Tuple
1640
1669
  The merged results from callback_class if provided otherwise None.
1670
+
1671
+ Examples
1672
+ --------
1673
+ All data relevant to template matching will be contained in ``matching_data``, which
1674
+ is a :py:class:`tme.matching_data.MatchingData` instance and can be created like so
1675
+
1676
+ >>> import numpy as np
1677
+ >>> from tme.matching_data import MatchingData
1678
+ >>> from tme.matching_utils import get_rotation_matrices
1679
+ >>> target = np.random.rand(50,40,60)
1680
+ >>> template = target[15:25, 10:20, 30:40]
1681
+ >>> matching_data = MatchingData(target, template)
1682
+ >>> matching_data.rotations = get_rotation_matrices(
1683
+ ... angular_sampling = 60, dim = target.ndim
1684
+ ... )
1685
+
1686
+ The template matching procedure is determined by ``matching_setup`` and
1687
+ ``matching_score``, which are unique to each score. In the following,
1688
+ we will be using the `FLCSphericalMask` score, which is composed of
1689
+ :py:meth:`flcSphericalMask_setup` and :py:meth:`corr_scoring`
1690
+
1691
+ >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
1692
+ >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
1693
+ >>> matching_setup, matching_score = funcs
1694
+
1695
+ Computed scores are flexibly analyzed by being passed through an analyzer. In the
1696
+ following, we will use :py:class:`tme.analyzer.MaxScoreOverRotations` to
1697
+ aggregate scores over rotations
1698
+
1699
+ >>> from tme.analyzer import MaxScoreOverRotations
1700
+ >>> callback_class = MaxScoreOverRotations
1701
+ >>> callback_class_args = {"score_threshold" : 0}
1702
+
1703
+ In case the entire template matching problem does not fit into memory, we can
1704
+ determine the splitting procedure. In this case, we halve the first axis of the target
1705
+ once. Splitting and ``job_schedule`` are typically computed using
1706
+ :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1707
+
1708
+ >>> target_splits = {0 : 1}
1709
+
1710
+ Finally, we can perform template matching. Note that the data
1711
+ contained in ``matching_data`` will be destroyed when running the following
1712
+
1713
+ >>> from tme.matching_exhaustive import scan_subsets
1714
+ >>> results = scan_subsets(
1715
+ ... matching_data = matching_data,
1716
+ ... matching_score = matching_score,
1717
+ ... matching_setup = matching_setup,
1718
+ ... callback_class = callback_class,
1719
+ ... callback_class_args = callback_class_args,
1720
+ ... target_splits = target_splits,
1721
+ ... )
1722
+
1723
+ The returned ``results`` tuple contains the output of the chosen analyzer.
1724
+
1725
+ See Also
1726
+ --------
1727
+ :py:meth:`tme.matching_utils.compute_parallelization_schedule`
1641
1728
  """
1642
1729
  target_splits = split_numpy_array_slices(
1643
1730
  matching_data._target.shape, splits=target_splits
1644
1731
  )
1732
+ if (len(target_splits) > 1) and not pad_target_edges:
1733
+ warnings.warn(
1734
+ "Target splitting without padding target edges leads to unreliable "
1735
+ "similarity estimates around the split border."
1736
+ )
1737
+
1645
1738
  template_splits = split_numpy_array_slices(
1646
1739
  matching_data._template.shape, splits=template_splits
1647
1740
  )