pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
  2. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +473 -140
  8. scripts/match_template_filters.py +458 -169
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +278 -148
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +22 -12
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +85 -64
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +86 -60
  24. tme/matching_exhaustive.py +245 -166
  25. tme/matching_optimization.py +137 -69
  26. tme/matching_utils.py +1 -1
  27. tme/orientations.py +175 -55
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +188 -0
  30. tme/preprocessing/composable_filter.py +31 -0
  31. tme/preprocessing/compose.py +51 -0
  32. tme/preprocessing/frequency_filters.py +378 -0
  33. tme/preprocessing/tilt_series.py +1017 -0
  34. tme/preprocessor.py +17 -7
  35. tme/structure.py +4 -1
  36. pytme-0.2.0b0.dist-info/RECORD +0 -66
  37. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  38. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  39. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  40. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  41. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@
7
7
  import os
8
8
  import sys
9
9
  import warnings
10
- from copy import deepcopy
11
10
  from itertools import product
12
11
  from typing import Callable, Tuple, Dict
13
12
  from functools import wraps
@@ -19,14 +18,12 @@ from scipy.ndimage import laplace
19
18
 
20
19
  from .analyzer import MaxScoreOverRotations
21
20
  from .matching_utils import (
22
- apply_convolution_mode,
23
21
  handle_traceback,
24
22
  split_numpy_array_slices,
25
23
  conditional_execute,
26
24
  )
27
25
  from .matching_memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
28
- # from .preprocessing import Compose
29
- from .preprocessor import Preprocessor
26
+ from .preprocessing import Compose
30
27
  from .matching_data import MatchingData
31
28
  from .backends import backend
32
29
  from .types import NDArray, CallbackClass
@@ -64,9 +61,6 @@ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> No
64
61
  ----------
65
62
  .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
66
63
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
67
- .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
68
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
69
- 13375 (2023)
70
64
 
71
65
  Returns
72
66
  -------
@@ -87,6 +81,16 @@ def normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> No
87
81
  return None
88
82
 
89
83
 
84
+ def _normalize_under_mask_overflow_safe(
85
+ template: NDArray, mask: NDArray, mask_intensity
86
+ ) -> None:
87
+ _template = backend.astype(template, backend._overflow_safe_dtype)
88
+ _mask = backend.astype(mask, backend._overflow_safe_dtype)
89
+ normalize_under_mask(template=_template, mask=_mask, mask_intensity=mask_intensity)
90
+ template[:] = backend.astype(_template, template.dtype)
91
+ return None
92
+
93
+
90
94
  def apply_filter(ft_template, template_filter):
91
95
  # This is an approximation to applying the mask, irfftn, normalize, rfftn
92
96
  std_before = backend.std(ft_template)
@@ -103,8 +107,6 @@ def cc_setup(
103
107
  target: NDArray,
104
108
  fast_shape: Tuple[int],
105
109
  fast_ft_shape: Tuple[int],
106
- real_dtype: type,
107
- complex_dtype: type,
108
110
  shared_memory_handler: Callable,
109
111
  callback_class: Callable,
110
112
  callback_class_args: Dict,
@@ -124,7 +126,7 @@ def cc_setup(
124
126
  :py:meth:`corr_scoring`
125
127
  :py:class:`tme.matching_optimization.CrossCorrelation`
126
128
  """
127
- target_shape = target.shape
129
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
128
130
  target_pad = backend.topleft_pad(target, fast_shape)
129
131
  target_pad_ft = backend.preallocate_array(fast_ft_shape, complex_dtype)
130
132
 
@@ -141,28 +143,26 @@ def cc_setup(
141
143
  arr=backend.preallocate_array(1, real_dtype) + 1,
142
144
  shared_memory_handler=shared_memory_handler,
143
145
  )
144
- numerator2_buffer = backend.arr_to_sharedarr(
146
+ numerator_buffer = backend.arr_to_sharedarr(
145
147
  arr=backend.preallocate_array(1, real_dtype),
146
148
  shared_memory_handler=shared_memory_handler,
147
149
  )
148
150
 
149
151
  target_ft_tuple = (target_ft_out, fast_ft_shape, complex_dtype)
150
- template_tuple = (template_out, template.shape, real_dtype)
152
+ template_tuple = (template_out, template.shape, template.dtype)
151
153
 
152
154
  inv_denominator_tuple = (inv_denominator_buffer, (1,), real_dtype)
153
- numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
155
+ numerator_tuple = (numerator_buffer, (1,), real_dtype)
154
156
 
155
157
  ret = {
156
158
  "template": template_tuple,
157
159
  "ft_target": target_ft_tuple,
158
160
  "inv_denominator": inv_denominator_tuple,
159
- "numerator2": numerator2_tuple,
160
- "targetshape": target_shape,
161
+ "numerator": numerator_tuple,
162
+ "targetshape": target.shape,
161
163
  "templateshape": template.shape,
162
164
  "fast_shape": fast_shape,
163
165
  "fast_ft_shape": fast_ft_shape,
164
- "real_dtype": real_dtype,
165
- "complex_dtype": complex_dtype,
166
166
  "callback_class": callback_class,
167
167
  "callback_class_args": callback_class_args,
168
168
  "use_memmap": kwargs.get("use_memmap", False),
@@ -200,8 +200,6 @@ def corr_setup(
200
200
  target: NDArray,
201
201
  fast_shape: Tuple[int],
202
202
  fast_ft_shape: Tuple[int],
203
- real_dtype: type,
204
- complex_dtype: type,
205
203
  shared_memory_handler: Callable,
206
204
  callback_class: Callable,
207
205
  callback_class_args: Dict,
@@ -233,6 +231,7 @@ def corr_setup(
233
231
  :py:meth:`corr_scoring`
234
232
  :py:class:`tme.matching_optimization.NormalizedCrossCorrelation`.
235
233
  """
234
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
236
235
  target_pad = backend.topleft_pad(target, fast_shape)
237
236
 
238
237
  # The exact composition of the denominator is debatable
@@ -266,14 +265,14 @@ def corr_setup(
266
265
  template_mean = backend.divide(template_mean, n_observations)
267
266
  template_ssd = backend.sum(
268
267
  backend.square(
269
- backend.multiply(backend.multiply(template, template_mean), template_mask)
268
+ backend.multiply(backend.subtract(template, template_mean), template_mask)
270
269
  )
271
270
  )
272
- template_volume = np.prod(template.shape)
271
+ template_volume = np.prod(tuple(int(x) for x in template.shape))
273
272
  backend.multiply(template, template_mask, out=template)
274
273
 
275
- # Final numerator is score - numerator2
276
- numerator2 = backend.multiply(target_window_sum, template_mean)
274
+ # Final numerator is score - numerator
275
+ numerator = backend.multiply(target_window_sum, template_mean)
277
276
 
278
277
  # Compute denominator
279
278
  backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
@@ -300,29 +299,27 @@ def corr_setup(
300
299
  inv_denominator_buffer = backend.arr_to_sharedarr(
301
300
  arr=inv_denominator, shared_memory_handler=shared_memory_handler
302
301
  )
303
- numerator2_buffer = backend.arr_to_sharedarr(
304
- arr=numerator2, shared_memory_handler=shared_memory_handler
302
+ numerator_buffer = backend.arr_to_sharedarr(
303
+ arr=numerator, shared_memory_handler=shared_memory_handler
305
304
  )
306
305
 
307
- template_tuple = (template_buffer, deepcopy(template.shape), real_dtype)
306
+ template_tuple = (template_buffer, template.shape, template.dtype)
308
307
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
309
308
 
310
309
  inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
311
- numerator2_tuple = (numerator2_buffer, fast_shape, real_dtype)
310
+ numerator_tuple = (numerator_buffer, fast_shape, real_dtype)
312
311
 
313
- ft_target, inv_denominator, numerator2 = None, None, None
312
+ ft_target, inv_denominator, numerator = None, None, None
314
313
 
315
314
  ret = {
316
315
  "template": template_tuple,
317
316
  "ft_target": target_ft_tuple,
318
317
  "inv_denominator": inv_denominator_tuple,
319
- "numerator2": numerator2_tuple,
320
- "targetshape": deepcopy(target.shape),
321
- "templateshape": deepcopy(template.shape),
318
+ "numerator": numerator_tuple,
319
+ "targetshape": target.shape,
320
+ "templateshape": template.shape,
322
321
  "fast_shape": fast_shape,
323
322
  "fast_ft_shape": fast_ft_shape,
324
- "real_dtype": real_dtype,
325
- "complex_dtype": complex_dtype,
326
323
  "callback_class": callback_class,
327
324
  "callback_class_args": callback_class_args,
328
325
  "template_mean": kwargs.get("template_mean", template_mean),
@@ -374,8 +371,6 @@ def flc_setup(
374
371
  target: NDArray,
375
372
  fast_shape: Tuple[int],
376
373
  fast_ft_shape: Tuple[int],
377
- real_dtype: type,
378
- complex_dtype: type,
379
374
  shared_memory_handler: Callable,
380
375
  callback_class: Callable,
381
376
  callback_class_args: Dict,
@@ -414,8 +409,8 @@ def flc_setup(
414
409
  target_pad = backend.topleft_pad(target, fast_shape)
415
410
 
416
411
  # Target and squared target window sums
417
- ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
418
- ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
412
+ ft_target = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
413
+ ft_target2 = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
419
414
  rfftn(target_pad, ft_target)
420
415
  backend.square(target_pad, out=target_pad)
421
416
  rfftn(target_pad, ft_target2)
@@ -439,11 +434,15 @@ def flc_setup(
439
434
  arr=template_mask, shared_memory_handler=shared_memory_handler
440
435
  )
441
436
 
442
- template_tuple = (template_buffer, template.shape, real_dtype)
443
- template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)
437
+ template_tuple = (template_buffer, template.shape, template.dtype)
438
+ template_mask_tuple = (
439
+ template_mask_buffer,
440
+ template_mask.shape,
441
+ template_mask.dtype,
442
+ )
444
443
 
445
- target_ft_tuple = (ft_target, fast_ft_shape, complex_dtype)
446
- target_ft2_tuple = (ft_target2, fast_ft_shape, complex_dtype)
444
+ target_ft_tuple = (ft_target, fast_ft_shape, backend._complex_dtype)
445
+ target_ft2_tuple = (ft_target2, fast_ft_shape, backend._complex_dtype)
447
446
 
448
447
  ret = {
449
448
  "template": template_tuple,
@@ -454,8 +453,6 @@ def flc_setup(
454
453
  "templateshape": template.shape,
455
454
  "fast_shape": fast_shape,
456
455
  "fast_ft_shape": fast_ft_shape,
457
- "real_dtype": real_dtype,
458
- "complex_dtype": complex_dtype,
459
456
  "callback_class": callback_class,
460
457
  "callback_class_args": callback_class_args,
461
458
  }
@@ -471,8 +468,6 @@ def flcSphericalMask_setup(
471
468
  target: NDArray,
472
469
  fast_shape: Tuple[int],
473
470
  fast_ft_shape: Tuple[int],
474
- real_dtype: type,
475
- complex_dtype: type,
476
471
  shared_memory_handler: Callable,
477
472
  callback_class: Callable,
478
473
  callback_class_args: Dict,
@@ -509,6 +504,7 @@ def flcSphericalMask_setup(
509
504
  --------
510
505
  :py:meth:`corr_scoring`
511
506
  """
507
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
512
508
  target_pad = backend.topleft_pad(target, fast_shape)
513
509
 
514
510
  # Target and squared target window sums
@@ -518,10 +514,14 @@ def flcSphericalMask_setup(
518
514
 
519
515
  temp = backend.preallocate_array(fast_shape, real_dtype)
520
516
  temp2 = backend.preallocate_array(fast_shape, real_dtype)
521
- numerator2 = backend.preallocate_array(1, real_dtype)
517
+ numerator = backend.preallocate_array(1, real_dtype)
522
518
 
523
- eps = backend.eps(real_dtype)
524
- n_observations = backend.sum(template_mask)
519
+ n_observations, norm_func = backend.sum(template_mask), normalize_under_mask
520
+ if backend.datatype_bytes(template_mask.dtype) == 2:
521
+ norm_func = _normalize_under_mask_overflow_safe
522
+ n_observations = backend.sum(
523
+ backend.astype(template_mask, backend._overflow_safe_dtype)
524
+ )
525
525
 
526
526
  template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
527
527
  rfftn(template_mask_pad, ft_template_mask)
@@ -544,15 +544,11 @@ def flcSphericalMask_setup(
544
544
  backend.sqrt(temp, out=temp)
545
545
  backend.multiply(temp, n_observations, out=temp)
546
546
 
547
- tol = 1e3 * eps * backend.max(backend.abs(temp))
548
- nonzero_indices = temp > tol
549
-
550
547
  backend.fill(temp2, 0)
548
+ nonzero_indices = temp > backend.eps(real_dtype)
551
549
  temp2[nonzero_indices] = 1 / temp[nonzero_indices]
552
550
 
553
- normalize_under_mask(
554
- template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
555
- )
551
+ norm_func(template=template, mask=template_mask, mask_intensity=n_observations)
556
552
 
557
553
  template_buffer = backend.arr_to_sharedarr(
558
554
  arr=template, shared_memory_handler=shared_memory_handler
@@ -566,29 +562,27 @@ def flcSphericalMask_setup(
566
562
  inv_denominator_buffer = backend.arr_to_sharedarr(
567
563
  arr=temp2, shared_memory_handler=shared_memory_handler
568
564
  )
569
- numerator2_buffer = backend.arr_to_sharedarr(
570
- arr=numerator2, shared_memory_handler=shared_memory_handler
565
+ numerator_buffer = backend.arr_to_sharedarr(
566
+ arr=numerator, shared_memory_handler=shared_memory_handler
571
567
  )
572
568
 
573
- template_tuple = (template_buffer, template.shape, real_dtype)
574
- template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
569
+ template_tuple = (template_buffer, template.shape, template.dtype)
570
+ template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
575
571
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
576
572
 
577
573
  inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
578
- numerator2_tuple = (numerator2_buffer, (1,), real_dtype)
574
+ numerator_tuple = (numerator_buffer, (1,), real_dtype)
579
575
 
580
576
  ret = {
581
577
  "template": template_tuple,
582
578
  "template_mask": template_mask_tuple,
583
579
  "ft_target": target_ft_tuple,
584
580
  "inv_denominator": inv_denominator_tuple,
585
- "numerator2": numerator2_tuple,
581
+ "numerator": numerator_tuple,
586
582
  "targetshape": target.shape,
587
583
  "templateshape": template.shape,
588
584
  "fast_shape": fast_shape,
589
585
  "fast_ft_shape": fast_ft_shape,
590
- "real_dtype": real_dtype,
591
- "complex_dtype": complex_dtype,
592
586
  "callback_class": callback_class,
593
587
  "callback_class_args": callback_class_args,
594
588
  }
@@ -605,8 +599,6 @@ def mcc_setup(
605
599
  target_mask: NDArray,
606
600
  fast_shape: Tuple[int],
607
601
  fast_ft_shape: Tuple[int],
608
- real_dtype: type,
609
- complex_dtype: type,
610
602
  shared_memory_handler: Callable,
611
603
  callback_class: Callable,
612
604
  callback_class_args: Dict,
@@ -643,6 +635,7 @@ def mcc_setup(
643
635
  :py:meth:`mcc_scoring`
644
636
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
645
637
  """
638
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
646
639
  target = backend.multiply(target, target_mask > 0, out=target)
647
640
 
648
641
  target_pad = backend.topleft_pad(target, fast_shape)
@@ -673,8 +666,8 @@ def mcc_setup(
673
666
  arr=template_mask, shared_memory_handler=shared_memory_handler
674
667
  )
675
668
 
676
- template_tuple = (template_buffer, template.shape, real_dtype)
677
- template_mask_tuple = (template_mask_buffer, template.shape, real_dtype)
669
+ template_tuple = (template_buffer, template.shape, template.dtype)
670
+ template_mask_tuple = (template_mask_buffer, template.shape, template_mask.dtype)
678
671
 
679
672
  target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)
680
673
  target_ft2_tuple = (target_ft2_buffer, fast_ft_shape, complex_dtype)
@@ -690,8 +683,6 @@ def mcc_setup(
690
683
  "templateshape": template.shape,
691
684
  "fast_shape": fast_shape,
692
685
  "fast_ft_shape": fast_ft_shape,
693
- "real_dtype": real_dtype,
694
- "complex_dtype": complex_dtype,
695
686
  "callback_class": callback_class,
696
687
  "callback_class_args": callback_class_args,
697
688
  }
@@ -703,15 +694,13 @@ def corr_scoring(
703
694
  template: Tuple[type, Tuple[int], type],
704
695
  ft_target: Tuple[type, Tuple[int], type],
705
696
  inv_denominator: Tuple[type, Tuple[int], type],
706
- numerator2: Tuple[type, Tuple[int], type],
697
+ numerator: Tuple[type, Tuple[int], type],
707
698
  template_filter: Tuple[type, Tuple[int], type],
708
699
  targetshape: Tuple[int],
709
700
  templateshape: Tuple[int],
710
701
  fast_shape: Tuple[int],
711
702
  fast_ft_shape: Tuple[int],
712
703
  rotations: NDArray,
713
- real_dtype: type,
714
- complex_dtype: type,
715
704
  callback_class: CallbackClass,
716
705
  callback_class_args: Dict,
717
706
  interpolation_order: int,
@@ -723,7 +712,7 @@ def corr_scoring(
723
712
 
724
713
  .. math::
725
714
 
726
- (CC(f,g) - numerator2) \\cdot inv\\_denominator
715
+ (CC(f,g) - numerator) \\cdot inv\\_denominator
727
716
 
728
717
  Parameters
729
718
  ----------
@@ -735,8 +724,8 @@ def corr_scoring(
735
724
  inv_denominator : Tuple[type, Tuple[int], type]
736
725
  Tuple containing a pointer to the inverse denominator data, its shape, and its
737
726
  datatype.
738
- numerator2 : Tuple[type, Tuple[int], type]
739
- Tuple containing a pointer to the numerator2 data, its shape, and its datatype.
727
+ numerator : Tuple[type, Tuple[int], type]
728
+ Tuple containing a pointer to the numerator data, its shape, and its datatype.
740
729
  fast_shape : Tuple[int]
741
730
  The shape for fast Fourier transform.
742
731
  fast_ft_shape : Tuple[int]
@@ -770,6 +759,8 @@ def corr_scoring(
770
759
  :py:meth:`cam_setup`
771
760
  :py:meth:`flcSphericalMask_setup`
772
761
  """
762
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
763
+
773
764
  callback = callback_class
774
765
  if callback_class is not None and isinstance(callback_class, type):
775
766
  callback = callback_class(**callback_class_args)
@@ -778,15 +769,33 @@ def corr_scoring(
778
769
  template = backend.sharedarr_to_arr(template_buffer, template_shape, template_dtype)
779
770
  ft_target = backend.sharedarr_to_arr(*ft_target)
780
771
  inv_denominator = backend.sharedarr_to_arr(*inv_denominator)
781
- numerator2 = backend.sharedarr_to_arr(*numerator2)
772
+ numerator = backend.sharedarr_to_arr(*numerator)
782
773
  template_filter = backend.sharedarr_to_arr(*template_filter)
783
774
 
775
+ norm_func = normalize_under_mask
784
776
  norm_template, template_mask, mask_sum = False, 1, 1
785
777
  if "template_mask" in kwargs:
786
- template_mask = backend.sharedarr_to_arr(
787
- kwargs["template_mask"][0], template_shape, template_dtype
788
- )
789
- norm_template, mask_sum = True, backend.sum(template_mask)
778
+ norm_template = True
779
+ template_mask = backend.sharedarr_to_arr(*kwargs["template_mask"])
780
+ mask_sum = backend.sum(template_mask)
781
+ if backend.datatype_bytes(template_mask.dtype) == 2:
782
+ norm_func = _normalize_under_mask_overflow_safe
783
+ mask_sum = backend.sum(
784
+ backend.astype(template_mask, backend._overflow_safe_dtype)
785
+ )
786
+
787
+ norm_template = conditional_execute(norm_func, norm_template)
788
+ norm_numerator = (backend.sum(numerator) != 0) & (backend.size(numerator) != 1)
789
+ norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
790
+
791
+ norm_denominator = (backend.sum(inv_denominator) != 1) & (
792
+ backend.size(inv_denominator) != 1
793
+ )
794
+ norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
795
+ callback_func = conditional_execute(callback, callback_class is not None)
796
+ template_filter_func = conditional_execute(
797
+ apply_filter, backend.size(template_filter) != 1
798
+ )
790
799
 
791
800
  arr = backend.preallocate_array(fast_shape, real_dtype)
792
801
  ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
@@ -801,19 +810,6 @@ def corr_scoring(
801
810
  temp_fft=ft_temp,
802
811
  )
803
812
 
804
- norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
805
- norm_denominator = (backend.sum(inv_denominator) != 1) & (
806
- backend.size(inv_denominator) != 1
807
- )
808
-
809
- norm_template = conditional_execute(normalize_under_mask, norm_template)
810
- callback_func = conditional_execute(callback, callback_class is not None)
811
- norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
812
- norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
813
- template_filter_func = conditional_execute(
814
- apply_filter, backend.size(template_filter) != 1
815
- )
816
-
817
813
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
818
814
  for index in range(rotations.shape[0]):
819
815
  rotation = rotations[index]
@@ -822,10 +818,9 @@ def corr_scoring(
822
818
  arr=template,
823
819
  rotation_matrix=rotation,
824
820
  out=arr,
825
- use_geometric_center=False,
821
+ use_geometric_center=True,
826
822
  order=interpolation_order,
827
823
  )
828
-
829
824
  norm_template(arr[unpadded_slice], template_mask, mask_sum)
830
825
 
831
826
  rfftn(arr, ft_temp)
@@ -833,7 +828,7 @@ def corr_scoring(
833
828
  backend.multiply(ft_target, ft_temp, out=ft_temp)
834
829
  irfftn(ft_temp, arr)
835
830
 
836
- norm_func_numerator(arr, numerator2, out=arr)
831
+ norm_func_numerator(arr, numerator, out=arr)
837
832
  norm_func_denominator(arr, inv_denominator, out=arr)
838
833
 
839
834
  callback_func(
@@ -857,8 +852,6 @@ def flc_scoring(
857
852
  fast_shape: Tuple[int],
858
853
  fast_ft_shape: Tuple[int],
859
854
  rotations: NDArray,
860
- real_dtype: type,
861
- complex_dtype: type,
862
855
  callback_class: CallbackClass,
863
856
  callback_class_args: Dict,
864
857
  interpolation_order: int,
@@ -890,6 +883,8 @@ def flc_scoring(
890
883
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
891
884
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
892
885
  """
886
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
887
+
893
888
  callback = callback_class
894
889
  if callback_class is not None and isinstance(callback_class, type):
895
890
  callback = callback_class(**callback_class_args)
@@ -933,7 +928,7 @@ def flc_scoring(
933
928
  rotation_matrix=rotation,
934
929
  out=arr,
935
930
  out_mask=temp,
936
- use_geometric_center=False,
931
+ use_geometric_center=True,
937
932
  order=interpolation_order,
938
933
  )
939
934
  # Given the amount of FFTs, might aswell normalize properly
@@ -966,7 +961,8 @@ def flc_scoring(
966
961
  backend.multiply(ft_target, ft_temp, out=ft_temp)
967
962
  irfftn(ft_temp, arr)
968
963
 
969
- tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
964
+ tol = eps
965
+ # tol = 1e3 * eps * backend.max(backend.abs(temp))
970
966
  nonzero_indices = temp > tol
971
967
  backend.fill(temp2, 0)
972
968
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
@@ -992,8 +988,6 @@ def flc_scoring2(
992
988
  fast_shape: Tuple[int],
993
989
  fast_ft_shape: Tuple[int],
994
990
  rotations: NDArray,
995
- real_dtype: type,
996
- complex_dtype: type,
997
991
  callback_class: CallbackClass,
998
992
  callback_class_args: Dict,
999
993
  interpolation_order: int,
@@ -1025,6 +1019,7 @@ def flc_scoring2(
1025
1019
  .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
1026
1020
  and F. Förster, J. Struct. Biol. 178, 177 (2012).
1027
1021
  """
1022
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1028
1023
  callback = callback_class
1029
1024
  if callback_class is not None and isinstance(callback_class, type):
1030
1025
  callback = callback_class(**callback_class_args)
@@ -1083,7 +1078,7 @@ def flc_scoring2(
1083
1078
  rotation_matrix=rotation,
1084
1079
  out=arr[squeeze],
1085
1080
  out_mask=temp[squeeze],
1086
- use_geometric_center=False,
1081
+ use_geometric_center=True,
1087
1082
  order=interpolation_order,
1088
1083
  )
1089
1084
  # Given the amount of FFTs, might aswell normalize properly
@@ -1115,8 +1110,7 @@ def flc_scoring2(
1115
1110
  backend.multiply(ft_target, ft_temp[squeeze_fast_ft], out=ft_denom)
1116
1111
  irfftn(ft_denom, arr)
1117
1112
 
1118
- tol = tol = 1e3 * eps * backend.max(backend.abs(temp))
1119
- nonzero_indices = temp > tol
1113
+ nonzero_indices = temp > eps
1120
1114
  backend.fill(temp2, 0)
1121
1115
  temp2[nonzero_indices] = arr[nonzero_indices] / temp[nonzero_indices]
1122
1116
 
@@ -1141,8 +1135,6 @@ def mcc_scoring(
1141
1135
  fast_shape: Tuple[int],
1142
1136
  fast_ft_shape: Tuple[int],
1143
1137
  rotations: NDArray,
1144
- real_dtype: type,
1145
- complex_dtype: type,
1146
1138
  callback_class: CallbackClass,
1147
1139
  callback_class_args: type,
1148
1140
  interpolation_order: int,
@@ -1179,6 +1171,7 @@ def mcc_scoring(
1179
1171
  --------
1180
1172
  :py:class:`tme.matching_optimization.MaskedCrossCorrelation`
1181
1173
  """
1174
+ real_dtype, complex_dtype = backend._float_dtype, backend._complex_dtype
1182
1175
  callback = callback_class
1183
1176
  if callback_class is not None and isinstance(callback_class, type):
1184
1177
  callback = callback_class(**callback_class_args)
@@ -1232,7 +1225,7 @@ def mcc_scoring(
1232
1225
  rotation_matrix=rotation,
1233
1226
  out=template_rot,
1234
1227
  out_mask=temp,
1235
- use_geometric_center=False,
1228
+ use_geometric_center=True,
1236
1229
  order=interpolation_order,
1237
1230
  )
1238
1231
 
@@ -1307,7 +1300,7 @@ def mcc_scoring(
1307
1300
  return callback
1308
1301
 
1309
1302
 
1310
- def _setup_template_filter(
1303
+ def _setup_template_filter_apply_target_filter(
1311
1304
  matching_data: MatchingData,
1312
1305
  rfftn: Callable,
1313
1306
  irfftn: Callable,
@@ -1317,28 +1310,36 @@ def _setup_template_filter(
1317
1310
  filter_template = isinstance(matching_data.template_filter, Compose)
1318
1311
  filter_target = isinstance(matching_data.target_filter, Compose)
1319
1312
 
1320
- template_filter = backend.full(
1321
- shape=(1,), fill_value=1, dtype=backend._default_dtype
1322
- )
1313
+ template_filter = backend.full(shape=(1,), fill_value=1, dtype=backend._float_dtype)
1323
1314
 
1324
1315
  if not filter_template and not filter_target:
1325
1316
  return template_filter
1326
1317
 
1327
1318
  target_temp = backend.astype(
1328
- backend.topleft_pad(matching_data.target, fast_shape), backend._default_dtype
1319
+ backend.topleft_pad(matching_data.target, fast_shape), backend._float_dtype
1329
1320
  )
1330
1321
  target_temp_ft = backend.preallocate_array(fast_ft_shape, backend._complex_dtype)
1331
- rfftn(target_temp, target_temp_ft)
1332
1322
 
1323
+ filter_shape = backend.multiply(fast_ft_shape, 1 - matching_data._batch_mask)
1324
+ filter_shape[filter_shape == 0] = 1
1325
+ fast_shape = backend.multiply(fast_shape, 1 - matching_data._batch_mask)
1326
+ fast_shape = fast_shape[fast_shape != 0]
1327
+
1328
+ fast_shape = tuple(int(x) for x in fast_shape)
1329
+ filter_shape = tuple(int(x) for x in filter_shape)
1330
+
1331
+ rfftn(target_temp, target_temp_ft)
1333
1332
  if isinstance(matching_data.template_filter, Compose):
1334
1333
  template_filter = matching_data.template_filter(
1335
1334
  shape=fast_shape,
1336
1335
  return_real_fourier=True,
1337
1336
  shape_is_real_fourier=False,
1338
1337
  data_rfft=target_temp_ft,
1338
+ batch_dimension=matching_data._target_dims,
1339
1339
  )
1340
1340
  template_filter = template_filter["data"]
1341
1341
  template_filter[tuple(0 for _ in range(template_filter.ndim))] = 0
1342
+ template_filter = backend.reshape(template_filter, filter_shape)
1342
1343
 
1343
1344
  if isinstance(matching_data.target_filter, Compose):
1344
1345
  target_filter = matching_data.target_filter(
@@ -1347,8 +1348,10 @@ def _setup_template_filter(
1347
1348
  shape_is_real_fourier=False,
1348
1349
  data_rfft=target_temp_ft,
1349
1350
  weight_type=None,
1351
+ batch_dimension=matching_data._target_dims,
1350
1352
  )
1351
1353
  target_filter = target_filter["data"]
1354
+ target_filter = backend.reshape(target_filter, filter_shape)
1352
1355
  backend.multiply(target_temp_ft, target_filter, out=target_temp_ft)
1353
1356
 
1354
1357
  irfftn(target_temp_ft, target_temp)
@@ -1395,8 +1398,7 @@ def scan(
1395
1398
  **kwargs,
1396
1399
  ) -> Tuple:
1397
1400
  """
1398
- Perform template matching between target and template and sample
1399
- different rotations of template.
1401
+ Perform template matching.
1400
1402
 
1401
1403
  Parameters
1402
1404
  ----------
@@ -1420,25 +1422,44 @@ def scan(
1420
1422
  Order of spline interpolation for rotations.
1421
1423
  jobs_per_callback_class : int, optional
1422
1424
  How many jobs should be processed by a single callback_class instance,
1423
- if ones is provided.
1425
+ if one is provided.
1424
1426
  **kwargs : various
1425
- Additional arguments.
1427
+ Additional keyword arguments.
1426
1428
 
1427
1429
  Returns
1428
1430
  -------
1429
1431
  Tuple
1430
1432
  The merged results from callback_class if provided otherwise None.
1433
+
1434
+ Examples
1435
+ --------
1436
+ Schematically, using :py:meth:`scan` is similar to :py:meth:`scan_subsets`,
1437
+ with the distinction that the objects contained in ``matching_data`` are not
1438
+ split and the search is only parallelized over angles.
1439
+ Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
1440
+ can be invoked like so
1441
+
1442
+ >>> from tme.matching_exhaustive import scan
1443
+ >>> results = scan(
1444
+ >>> matching_data = matching_data,
1445
+ >>> matching_score = matching_score,
1446
+ >>> matching_setup = matching_setup,
1447
+ >>> callback_class = callback_class,
1448
+ >>> callback_class_args = callback_class_args,
1449
+ >>> )
1450
+
1431
1451
  """
1432
1452
  matching_data.to_backend()
1433
1453
  shape_diff = backend.subtract(
1434
1454
  matching_data._output_target_shape, matching_data._output_template_shape
1435
1455
  )
1436
- shape_diff = backend.multiply(shape_diff, ~matching_data._batch_mask)
1456
+ shape_diff = backend.multiply(shape_diff, 1 - matching_data._batch_mask)
1457
+
1437
1458
  if backend.sum(shape_diff < 0) and not pad_fourier:
1438
1459
  warnings.warn(
1439
1460
  "Target is larger than template and Fourier padding is turned off."
1440
- " This can lead to shifted results. You can swap template and target, or"
1441
- " zero-pad the target."
1461
+ " This can lead to shifted results. You can swap template and target,"
1462
+ " zero-pad the target or turn off template centering."
1442
1463
  )
1443
1464
  fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1444
1465
  pad_fourier=pad_fourier
@@ -1448,18 +1469,18 @@ def scan(
1448
1469
  rfftn, irfftn = backend.build_fft(
1449
1470
  fast_shape=fast_shape,
1450
1471
  fast_ft_shape=fast_ft_shape,
1451
- real_dtype=matching_data._default_dtype,
1452
- complex_dtype=matching_data._complex_dtype,
1472
+ real_dtype=backend._float_dtype,
1473
+ complex_dtype=backend._complex_dtype,
1453
1474
  fftargs=fftargs,
1454
1475
  )
1455
1476
 
1456
- # template_filter = _setup_template_filter(
1457
- # matching_data=matching_data,
1458
- # rfftn=rfftn,
1459
- # irfftn=irfftn,
1460
- # fast_shape=fast_shape,
1461
- # fast_ft_shape=fast_ft_shape,
1462
- # )
1477
+ template_filter = _setup_template_filter_apply_target_filter(
1478
+ matching_data=matching_data,
1479
+ rfftn=rfftn,
1480
+ irfftn=irfftn,
1481
+ fast_shape=fast_shape,
1482
+ fast_ft_shape=fast_ft_shape,
1483
+ )
1463
1484
 
1464
1485
  setup = matching_setup(
1465
1486
  rfftn=rfftn,
@@ -1470,31 +1491,14 @@ def scan(
1470
1491
  target_mask=matching_data.target_mask,
1471
1492
  fast_shape=fast_shape,
1472
1493
  fast_ft_shape=fast_ft_shape,
1473
- real_dtype=matching_data._default_dtype,
1474
- complex_dtype=matching_data._complex_dtype,
1475
1494
  callback_class=callback_class,
1476
1495
  callback_class_args=callback_class_args,
1477
1496
  **kwargs,
1478
1497
  )
1479
1498
  rfftn, irfftn = None, None
1480
1499
 
1481
-
1482
- template_filter, preprocessor = None, Preprocessor()
1483
- for method, parameters in matching_data.template_filter.items():
1484
- parameters["shape"] = fast_shape
1485
- parameters["omit_negative_frequencies"] = True
1486
- out = preprocessor.apply_method(method=method, parameters=parameters)
1487
- if template_filter is None:
1488
- template_filter = out
1489
- np.multiply(template_filter, out, out=template_filter)
1490
-
1491
- if template_filter is None:
1492
- template_filter = backend.full(
1493
- shape=(1,), fill_value=1, dtype=backend._default_dtype
1494
- )
1495
-
1496
1500
  template_filter = backend.to_backend_array(template_filter)
1497
- template_filter = backend.astype(template_filter, backend._default_dtype)
1501
+ template_filter = backend.astype(template_filter, backend._float_dtype)
1498
1502
  template_filter_buffer = backend.arr_to_sharedarr(
1499
1503
  arr=template_filter,
1500
1504
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
@@ -1520,7 +1524,6 @@ def scan(
1520
1524
  if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1521
1525
  convolution_mode = "valid"
1522
1526
 
1523
-
1524
1527
  callback_class_args["fourier_shift"] = fourier_shift
1525
1528
  callback_class_args["convolution_mode"] = convolution_mode
1526
1529
  callback_class_args["targetshape"] = setup["targetshape"]
@@ -1530,9 +1533,9 @@ def scan(
1530
1533
  callback_classes = [
1531
1534
  class_name(
1532
1535
  score_space_shape=fast_shape,
1533
- score_space_dtype=matching_data._default_dtype,
1536
+ score_space_dtype=backend._float_dtype,
1534
1537
  shared_memory_handler=kwargs.get("shared_memory_handler", None),
1535
- rotation_space_dtype=backend._default_dtype_int,
1538
+ rotation_space_dtype=backend._int_dtype,
1536
1539
  **callback_class_args,
1537
1540
  )
1538
1541
  for class_name in callback_classes
@@ -1571,23 +1574,31 @@ def scan(
1571
1574
 
1572
1575
  callbacks = callbacks[0:n_callback_classes]
1573
1576
  callbacks = [
1574
- tuple(callback._postprocess(
1575
- fourier_shift = fourier_shift,
1576
- convolution_mode = convolution_mode,
1577
- targetshape = setup["targetshape"],
1578
- templateshape = setup["templateshape"],
1579
- shared_memory_handler=kwargs.get("shared_memory_handler", None)
1580
- )) if hasattr(callback, "_postprocess") else tuple(callback)
1581
- for callback in callbacks if callback is not None
1577
+ tuple(
1578
+ callback._postprocess(
1579
+ fourier_shift=fourier_shift,
1580
+ convolution_mode=convolution_mode,
1581
+ targetshape=setup["targetshape"],
1582
+ templateshape=setup["templateshape"],
1583
+ shared_memory_handler=kwargs.get("shared_memory_handler", None),
1584
+ )
1585
+ )
1586
+ if hasattr(callback, "_postprocess")
1587
+ else tuple(callback)
1588
+ for callback in callbacks
1589
+ if callback is not None
1582
1590
  ]
1583
1591
  backend.free_cache()
1584
1592
 
1585
1593
  merged_callback = None
1586
1594
  if callback_class is not None:
1595
+ score_indices = None
1596
+ if hasattr(matching_data, "indices"):
1597
+ score_indices = matching_data.indices
1587
1598
  merged_callback = callback_class.merge(
1588
1599
  callbacks,
1589
1600
  **callback_class_args,
1590
- score_indices=matching_data.indices,
1601
+ score_indices=score_indices,
1591
1602
  inner_merge=True,
1592
1603
  )
1593
1604
 
@@ -1610,14 +1621,14 @@ def scan_subsets(
1610
1621
  **kwargs,
1611
1622
  ) -> Tuple:
1612
1623
  """
1613
- Wrapper around :py:meth:`scan` that supports template matching on splits
1614
- of template and target.
1624
+ Wrapper around :py:meth:`scan` that supports matching on splits
1625
+ of ``matching_data``.
1615
1626
 
1616
1627
  Parameters
1617
1628
  ----------
1618
- matching_data : MatchingData
1619
- Template matching data.
1620
- matching_func : type
1629
+ matching_data : :py:class:`tme.matching_data.MatchingData`
1630
+ MatchingData instance containing relevant data.
1631
+ matching_setup : type
1621
1632
  Function pointer to setup function.
1622
1633
  matching_score : type
1623
1634
  Function pointer to scoring function.
@@ -1626,11 +1637,15 @@ def scan_subsets(
1626
1637
  callback_class_args : dict, optional
1627
1638
  Arguments passed to the callback_class. Default is an empty dictionary.
1628
1639
  job_schedule : tuple of int, optional
1629
- Schedule of jobs. Default is (1, 1).
1640
+ Job scheduling scheme, default is (1, 1). First value corresponds
1641
+ to the number of splits that are processed in parallel, the second
1642
+ to the number of angles evaluated in parallel on each split.
1630
1643
  target_splits : dict, optional
1631
- Splits for target. Default is an empty dictionary, i.e. no splits
1644
+ Splits for target. Default is an empty dictionary, i.e. no splits.
1645
+ See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1632
1646
  template_splits : dict, optional
1633
1647
  Splits for template. Default is an empty dictionary, i.e. no splits.
1648
+ See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1634
1649
  pad_target_edges : bool, optional
1635
1650
  Whether to pad the target boundaries by half the template shape
1636
1651
  along each axis.
@@ -1652,10 +1667,74 @@ def scan_subsets(
1652
1667
  -------
1653
1668
  Tuple
1654
1669
  The merged results from callback_class if provided otherwise None.
1670
+
1671
+ Examples
1672
+ --------
1673
+ All data relevant to template matching will be contained in ``matching_data``, which
1674
+ is a :py:class:`tme.matching_data.MatchingData` instance and can be created like so
1675
+
1676
+ >>> import numpy as np
1677
+ >>> from tme.matching_data import MatchingData
1678
+ >>> from tme.matching_utils import get_rotation_matrices
1679
+ >>> target = np.random.rand(50,40,60)
1680
+ >>> template = target[15:25, 10:20, 30:40]
1681
+ >>> matching_data = MatchingData(target, template)
1682
+ >>> matching_data.rotations = get_rotation_matrices(
1683
+ ... angular_sampling = 60, dim = target.ndim
1684
+ ... )
1685
+
1686
+ The template matching procedure is determined by ``matching_setup`` and
1687
+ ``matching_score``, which are unique to each score. In the following,
1688
+ we will be using the `FLCSphericalMask` score, which is composed of
1689
+ :py:meth:`flcSphericalMask_setup` and :py:meth:`corr_scoring`
1690
+
1691
+ >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
1692
+ >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
1693
+ >>> matching_setup, matching_score = funcs
1694
+
1695
+ Computed scores are flexibly analyzed by being passed through an analyzer. In the
1696
+ following, we will use :py:class:`tme.analyzer.MaxScoreOverRotations` to
1697
+ aggregate scores over rotations
1698
+
1699
+ >>> from tme.analyzer import MaxScoreOverRotations
1700
+ >>> callback_class = MaxScoreOverRotations
1701
+ >>> callback_class_args = {"score_threshold" : 0}
1702
+
1703
+ In case the entire template matching problem does not fit into memory, we can
1704
+ determine the splitting procedure. In this case, we halve the first axis of the target
1705
+ once. Splitting and ``job_schedule`` are typically computed using
1706
+ :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
1707
+
1708
+ >>> target_splits = {0 : 1}
1709
+
1710
+ Finally, we can perform template matching. Note that the data
1711
+ contained in ``matching_data`` will be destroyed when running the following
1712
+
1713
+ >>> from tme.matching_exhaustive import scan_subsets
1714
+ >>> results = scan_subsets(
1715
+ ... matching_data = matching_data,
1716
+ ... matching_score = matching_score,
1717
+ ... matching_setup = matching_setup,
1718
+ ... callback_class = callback_class,
1719
+ ... callback_class_args = callback_class_args,
1720
+ ... target_splits = target_splits,
1721
+ ... )
1722
+
1723
+ The returned ``results`` tuple contains the output of the chosen analyzer.
1724
+
1725
+ See Also
1726
+ --------
1727
+ :py:meth:`tme.matching_utils.compute_parallelization_schedule`
1655
1728
  """
1656
1729
  target_splits = split_numpy_array_slices(
1657
1730
  matching_data._target.shape, splits=target_splits
1658
1731
  )
1732
+ if (len(target_splits) > 1) and not pad_target_edges:
1733
+ warnings.warn(
1734
+ "Target splitting without padding target edges leads to unreliable "
1735
+ "similarity estimates around the split border."
1736
+ )
1737
+
1659
1738
  template_splits = split_numpy_array_slices(
1660
1739
  matching_data._template.shape, splits=template_splits
1661
1740
  )