pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/postprocess.py +35 -21
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.post1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/METADATA +5 -7
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/RECORD +55 -48
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +35 -21
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_analyzer.py +2 -3
  20. tests/test_backends.py +3 -9
  21. tests/test_density.py +0 -1
  22. tests/test_extensions.py +0 -1
  23. tests/test_matching_utils.py +10 -60
  24. tests/test_rotations.py +1 -1
  25. tme/__version__.py +1 -1
  26. tme/analyzer/_utils.py +4 -4
  27. tme/analyzer/aggregation.py +35 -15
  28. tme/analyzer/peaks.py +11 -10
  29. tme/backends/_jax_utils.py +26 -13
  30. tme/backends/_numpyfftw_utils.py +270 -0
  31. tme/backends/cupy_backend.py +16 -55
  32. tme/backends/jax_backend.py +76 -37
  33. tme/backends/matching_backend.py +17 -51
  34. tme/backends/mlx_backend.py +1 -27
  35. tme/backends/npfftw_backend.py +71 -65
  36. tme/backends/pytorch_backend.py +1 -26
  37. tme/density.py +2 -6
  38. tme/extensions.cpython-311-darwin.so +0 -0
  39. tme/filters/ctf.py +22 -21
  40. tme/filters/wedge.py +10 -7
  41. tme/mask.py +341 -0
  42. tme/matching_data.py +31 -19
  43. tme/matching_exhaustive.py +37 -47
  44. tme/matching_optimization.py +2 -1
  45. tme/matching_scores.py +229 -411
  46. tme/matching_utils.py +73 -422
  47. tme/memory.py +1 -1
  48. tme/orientations.py +13 -8
  49. tme/rotations.py +1 -1
  50. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  51. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/estimate_memory_usage.py +0 -0
  52. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocess.py +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/WHEEL +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/entry_points.txt +0 -0
  55. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/licenses/LICENSE +0 -0
  56. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/top_level.txt +0 -0
tme/matching_scores.py CHANGED
@@ -18,138 +18,9 @@ from .matching_utils import (
18
18
  conditional_execute,
19
19
  identity,
20
20
  normalize_template,
21
- _normalize_template_overflow_safe,
22
21
  )
23
22
 
24
23
 
25
- def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
26
- """
27
- Determine whether ``shape1`` is equal to ``shape2``.
28
-
29
- Parameters
30
- ----------
31
- shape1, shape2 : tuple of ints
32
- Shapes to compare.
33
-
34
- Returns
35
- -------
36
- Bool
37
- ``shape1`` is equal to ``shape2``.
38
- """
39
- if len(shape1) != len(shape2):
40
- return False
41
- return shape1 == shape2
42
-
43
-
44
- def _create_filter_func(
45
- fwd_shape: Tuple[int],
46
- inv_shape: Tuple[int],
47
- arr_shape: Tuple[int],
48
- arr_filter: BackendArray,
49
- arr_ft_shape: Tuple[int],
50
- inv_output_shape: Tuple[int],
51
- real_dtype: type,
52
- cmpl_dtype: type,
53
- fwd_axes=None,
54
- inv_axes=None,
55
- rfftn: Callable = None,
56
- irfftn: Callable = None,
57
- ) -> Callable:
58
- """
59
- Configure template filtering function for Fourier transforms.
60
-
61
- Conceptually we distinguish between three cases. The base case
62
- is that both template and the corresponding filter have the same
63
- shape. Padding is used when the template filter is larger than
64
- the template, for instance to better resolve Fourier filters. Finally
65
- this function also handles the case when a filter is supposed to be
66
- broadcasted over the template batch dimension.
67
-
68
- Parameters
69
- ----------
70
- fwd_shape : tuple of ints
71
- Input shape of rfftn.
72
- inv_shape : tuple of ints
73
- Input shape of irfftn.
74
- arr_shape : tuple of ints
75
- Shape of the array to be filtered.
76
- arr_ft_shape : tuple of ints
77
- Shape of the Fourier transform of the array.
78
- arr_filter : BackendArray
79
- Precomputed filter to apply in the frequency domain.
80
- rfftn : Callable, optional
81
- Foward Fourier transform.
82
- irfftn : Callable, optional
83
- Inverse Fourier transform.
84
-
85
- Returns
86
- -------
87
- Callable
88
- Filter function with parameters template, ft_temp and template_filter.
89
- """
90
- if be.size(arr_filter) == 1:
91
- return conditional_execute(identity, identity, False)
92
-
93
- filter_shape = tuple(int(x) for x in arr_filter.shape)
94
- try:
95
- product_ft_shape = np.broadcast_shapes(arr_ft_shape, filter_shape)
96
- except ValueError:
97
- product_ft_shape, inv_output_shape = filter_shape, arr_shape
98
-
99
- rfft_valid = _shape_match(arr_shape, fwd_shape)
100
- rfft_valid = rfft_valid and _shape_match(product_ft_shape, inv_shape)
101
- rfft_valid = rfft_valid and rfftn is not None and irfftn is not None
102
-
103
- # FTTs were not or built for the wrong shape
104
- if not rfft_valid:
105
- _fwd_shape = arr_shape
106
- if all(x > y for x, y in zip(arr_shape, product_ft_shape)):
107
- _fwd_shape = fwd_shape
108
-
109
- rfftn, irfftn = be.build_fft(
110
- fwd_shape=_fwd_shape,
111
- inv_shape=product_ft_shape,
112
- real_dtype=real_dtype,
113
- cmpl_dtype=cmpl_dtype,
114
- inv_output_shape=inv_output_shape,
115
- fwd_axes=fwd_axes,
116
- inv_axes=inv_axes,
117
- )
118
-
119
- # Default case, all shapes are correctly matched
120
- def _apply_filter(template, ft_temp, template_filter):
121
- ft_temp = rfftn(template, ft_temp)
122
- ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
123
- return irfftn(ft_temp, template)
124
-
125
- if not _shape_match(arr_ft_shape, filter_shape):
126
- real_subset = tuple(slice(0, x) for x in arr_shape)
127
- _template = be.zeros(arr_shape, be._float_dtype)
128
- _ft_temp = be.zeros(product_ft_shape, be._complex_dtype)
129
-
130
- # Arr is padded, filter is not
131
- def _apply_filter_subset(template, ft_temp, template_filter):
132
- # TODO: Benchmark this
133
- _template[:] = template[real_subset]
134
- template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
135
- return template
136
-
137
- # Filter application requires a broadcasting operation
138
- def _apply_filter_broadcast(template, ft_temp, template_filter):
139
- _ft_prod = rfftn(template, _ft_temp2)
140
- _ft_res = be.multiply(_ft_prod, template_filter, out=_ft_temp)
141
- return irfftn(_ft_res, _template)
142
-
143
- if any(x > y and y == 1 for x, y in zip(filter_shape, arr_ft_shape)):
144
- _template = be.zeros(inv_output_shape, be._float_dtype)
145
- _ft_temp2 = be.zeros((1, *product_ft_shape[1:]), be._complex_dtype)
146
- return _apply_filter_broadcast
147
-
148
- return _apply_filter_subset
149
-
150
- return _apply_filter
151
-
152
-
153
24
  def cc_setup(
154
25
  matching_data: type,
155
26
  fast_shape: Tuple[int],
@@ -176,8 +47,6 @@ def cc_setup(
176
47
  axes = matching_data._batch_axis(matching_data._batch_mask)
177
48
 
178
49
  ret = {
179
- "fast_shape": fast_shape,
180
- "fast_ft_shape": fast_ft_shape,
181
50
  "template": be.to_sharedarr(matching_data.template, shm_handler),
182
51
  "ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
183
52
  "inv_denominator": be.to_sharedarr(
@@ -219,8 +88,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
219
88
  for subset in subsets:
220
89
  template[subset] = laplace(template[subset], mode="wrap")
221
90
 
222
- matching_data._target = target
223
- matching_data._template = template
91
+ matching_data._target = be.to_backend_array(target)
92
+ matching_data._template = be.to_backend_array(template)
224
93
 
225
94
  return cc_setup(matching_data=matching_data, **kwargs)
226
95
 
@@ -316,8 +185,6 @@ def corr_setup(
316
185
  denominator = be.multiply(denominator, mask, out=denominator)
317
186
 
318
187
  ret = {
319
- "fast_shape": fast_shape,
320
- "fast_ft_shape": fast_ft_shape,
321
188
  "template": be.to_sharedarr(template, shm_handler),
322
189
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
323
190
  "inv_denominator": be.to_sharedarr(denominator, shm_handler),
@@ -376,8 +243,6 @@ def flc_setup(
376
243
  ft_target2 = be.rfftn(target_pad, axes=data_axes)
377
244
 
378
245
  ret = {
379
- "fast_shape": fast_shape,
380
- "fast_ft_shape": fast_ft_shape,
381
246
  "template": be.to_sharedarr(matching_data.template, shm_handler),
382
247
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
383
248
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -438,8 +303,6 @@ def flcSphericalMask_setup(
438
303
 
439
304
  temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
440
305
  ret = {
441
- "fast_shape": fast_shape,
442
- "fast_ft_shape": fast_ft_shape,
443
306
  "template": be.to_sharedarr(matching_data.template, shm_handler),
444
307
  "template_mask": be.to_sharedarr(template_mask, shm_handler),
445
308
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -461,21 +324,14 @@ def mcc_setup(
461
324
  Setup function for :py:meth:`mcc_scoring`.
462
325
  """
463
326
  target, target_mask = matching_data.target, matching_data.target_mask
464
- target = be.multiply(target, target_mask > 0, out=target)
327
+ target = be.multiply(target, target_mask, out=target)
465
328
 
466
- target = be.topleft_pad(
467
- target,
468
- matching_data._batch_shape(fast_shape, matching_data._template_batch),
469
- )
470
- target_mask = be.topleft_pad(
471
- target_mask,
472
- matching_data._batch_shape(fast_shape, matching_data._template_batch),
473
- )
474
329
  ax = matching_data._batch_axis(matching_data._batch_mask)
330
+ shape = matching_data._batch_shape(fast_shape, matching_data._template_batch)
331
+ target = be.topleft_pad(target, shape)
332
+ target_mask = be.topleft_pad(target_mask, shape)
475
333
 
476
334
  ret = {
477
- "fast_shape": fast_shape,
478
- "fast_ft_shape": fast_ft_shape,
479
335
  "template": be.to_sharedarr(matching_data.template, shm_handler),
480
336
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
481
337
  "ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
@@ -500,6 +356,7 @@ def corr_scoring(
500
356
  callback: CallbackClass,
501
357
  interpolation_order: int,
502
358
  template_mask: shm_type = None,
359
+ score_mask: shm_type = None,
503
360
  ) -> CallbackClass:
504
361
  """
505
362
  Calculates a normalized cross-correlation between a target f and a template g.
@@ -538,70 +395,45 @@ def corr_scoring(
538
395
  Spline order for template rotations.
539
396
  template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
540
397
  Template mask data buffer, its shape and datatype, None by default.
398
+ score_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
399
+ Score mask data buffer, its shape and datatype, None by default.
541
400
 
542
401
  Returns
543
402
  -------
544
- Optional[CallbackClass]
545
- ``callback`` if provided otherwise None.
403
+ CallbackClass
546
404
  """
547
405
  template = be.from_sharedarr(template)
548
406
  ft_target = be.from_sharedarr(ft_target)
549
407
  inv_denominator = be.from_sharedarr(inv_denominator)
550
408
  numerator = be.from_sharedarr(numerator)
551
409
  template_filter = be.from_sharedarr(template_filter)
410
+ score_mask = be.from_sharedarr(score_mask)
552
411
 
553
- norm_func, norm_template, mask_sum = normalize_template, False, 1
412
+ n_obs = None
554
413
  if template_mask is not None:
555
414
  template_mask = be.from_sharedarr(template_mask)
556
- norm_template, mask_sum = True, be.sum(template_mask)
557
- if be.datatype_bytes(template_mask.dtype) == 2:
558
- norm_func = _normalize_template_overflow_safe
559
- mask_sum = be.sum(be.astype(template_mask, be._overflow_safe_dtype))
560
-
561
- callback_func = conditional_execute(callback, callback is not None)
562
- norm_template = conditional_execute(norm_func, norm_template)
563
- norm_numerator = conditional_execute(
564
- be.subtract, identity, _shape_match(numerator.shape, fast_shape)
565
- )
566
- norm_denominator = conditional_execute(
567
- be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
568
- )
415
+ n_obs = be.sum(template_mask) if template_mask is not None else None
416
+
417
+ norm_template = conditional_execute(normalize_template, n_obs is not None)
418
+ norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
419
+ norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
420
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
569
421
 
570
422
  arr = be.zeros(fast_shape, be._float_dtype)
571
423
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
572
-
573
- _fftargs = {
574
- "real_dtype": be._float_dtype,
575
- "cmpl_dtype": be._complex_dtype,
576
- "inv_output_shape": fast_shape,
577
- "fwd_axes": None,
578
- "inv_axes": None,
579
- "inv_shape": fast_ft_shape,
580
- "temp_fwd": arr,
581
- }
582
-
583
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
584
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
585
- _ = _fftargs.pop("temp_fwd", None)
424
+ template_rot = be.zeros(template.shape, be._float_dtype)
586
425
 
587
426
  template_filter_func = _create_filter_func(
588
427
  arr_shape=template.shape,
589
- arr_ft_shape=fast_ft_shape,
590
- arr_filter=template_filter,
591
- rfftn=rfftn,
592
- irfftn=irfftn,
593
- **_fftargs,
428
+ filter_shape=template_filter.shape,
594
429
  )
595
430
 
596
431
  center = be.divide(be.to_backend_array(template.shape) - 1, 2)
597
432
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
598
-
599
- template_rot = be.zeros(template.shape, be._float_dtype)
600
433
  for index in range(rotations.shape[0]):
601
- # d+1, d+1 rigid transform matrix from d,d rotation matrix
602
434
  rotation = rotations[index]
603
435
  matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
604
- template_rot, _ = be.rigid_transform(
436
+ _ = be.rigid_transform(
605
437
  arr=template,
606
438
  rotation_matrix=matrix,
607
439
  out=template_rot,
@@ -610,18 +442,19 @@ def corr_scoring(
610
442
  )
611
443
 
612
444
  template_rot = template_filter_func(template_rot, ft_temp, template_filter)
613
- norm_template(template_rot, template_mask, mask_sum)
445
+ norm_template(template_rot, template_mask, n_obs)
614
446
 
615
447
  arr = be.fill(arr, 0)
616
448
  arr[unpadded_slice] = template_rot
617
449
 
618
- ft_temp = rfftn(arr, ft_temp)
619
- ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
620
- arr = irfftn(ft_temp, arr)
450
+ ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
451
+ arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
452
+
453
+ arr = norm_sub(arr, numerator, out=arr)
454
+ arr = norm_mul(arr, inv_denominator, out=arr)
455
+ arr = norm_mask(arr, score_mask, out=arr)
621
456
 
622
- arr = norm_numerator(arr, numerator, out=arr)
623
- arr = norm_denominator(arr, inv_denominator, out=arr)
624
- callback_func(arr, rotation_matrix=rotation)
457
+ callback(arr, rotation_matrix=rotation)
625
458
 
626
459
  return callback
627
460
 
@@ -637,6 +470,7 @@ def flc_scoring(
637
470
  rotations: BackendArray,
638
471
  callback: CallbackClass,
639
472
  interpolation_order: int,
473
+ score_mask: shm_type = None,
640
474
  ) -> CallbackClass:
641
475
  """
642
476
  Computes a normalized cross-correlation between ``target`` (f),
@@ -682,6 +516,10 @@ def flc_scoring(
682
516
  interpolation_order : int
683
517
  Spline order for template rotations.
684
518
 
519
+ Returns
520
+ -------
521
+ CallbackClass
522
+
685
523
  References
686
524
  ----------
687
525
  .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
@@ -692,69 +530,56 @@ def flc_scoring(
692
530
  ft_target = be.from_sharedarr(ft_target)
693
531
  ft_target2 = be.from_sharedarr(ft_target2)
694
532
  template_filter = be.from_sharedarr(template_filter)
533
+ score_mask = be.from_sharedarr(score_mask)
695
534
 
696
535
  arr = be.zeros(fast_shape, float_dtype)
697
536
  temp = be.zeros(fast_shape, float_dtype)
698
537
  temp2 = be.zeros(fast_shape, float_dtype)
699
538
  ft_temp = be.zeros(fast_ft_shape, complex_dtype)
700
539
  ft_denom = be.zeros(fast_ft_shape, complex_dtype)
540
+ template_rot = be.zeros(template.shape, be._float_dtype)
541
+ template_mask_rot = be.zeros(template.shape, be._float_dtype)
701
542
 
702
- _fftargs = {
703
- "real_dtype": be._float_dtype,
704
- "cmpl_dtype": be._complex_dtype,
705
- "inv_output_shape": fast_shape,
706
- "fwd_axes": None,
707
- "inv_axes": None,
708
- "inv_shape": fast_ft_shape,
709
- "temp_fwd": arr,
710
- }
711
-
712
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
713
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
714
- _ = _fftargs.pop("temp_fwd", None)
715
-
716
- template_filter_func = _create_filter_func(
717
- arr_shape=template.shape,
718
- arr_ft_shape=fast_ft_shape,
719
- arr_filter=template_filter,
720
- rfftn=rfftn,
721
- irfftn=irfftn,
722
- **_fftargs,
723
- )
543
+ tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
544
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
724
545
 
725
546
  eps = be.eps(float_dtype)
726
- callback_func = conditional_execute(callback, callback is not None)
547
+ center = be.divide(be.to_backend_array(template.shape) - 1, 2)
548
+ unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
727
549
  for index in range(rotations.shape[0]):
728
550
  rotation = rotations[index]
729
- arr = be.fill(arr, 0)
730
- temp = be.fill(temp, 0)
731
- arr, temp = be.rigid_transform(
551
+ matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
552
+ _ = be.rigid_transform(
732
553
  arr=template,
733
554
  arr_mask=template_mask,
734
- rotation_matrix=rotation,
735
- out=arr,
736
- out_mask=temp,
555
+ rotation_matrix=matrix,
556
+ out=template_rot,
557
+ out_mask=template_mask_rot,
737
558
  use_geometric_center=True,
738
559
  order=interpolation_order,
739
560
  cache=True,
740
561
  )
741
562
 
742
- n_obs = be.sum(temp)
743
- arr = template_filter_func(arr, ft_temp, template_filter)
744
- arr = normalize_template(arr, temp, n_obs, axis=None)
563
+ n_obs = be.sum(template_mask_rot)
564
+ template_rot = tmpl_filter_func(template_rot, ft_temp, template_filter)
565
+ template_rot = normalize_template(template_rot, template_mask_rot, n_obs)
745
566
 
746
- ft_temp = rfftn(temp, ft_temp)
747
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
748
- temp = irfftn(ft_denom, temp)
749
- ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
750
- temp2 = irfftn(ft_denom, temp2)
567
+ arr = be.fill(arr, 0)
568
+ temp = be.fill(temp, 0)
569
+ arr[unpadded_slice] = template_rot
570
+ temp[unpadded_slice] = template_mask_rot
571
+
572
+ ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
573
+ temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
574
+ temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp, fast_shape)
751
575
 
752
- ft_temp = rfftn(arr, ft_temp)
753
- ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
754
- arr = irfftn(ft_temp, arr)
576
+ ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
577
+ arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
755
578
 
756
579
  arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
757
- callback_func(arr, rotation_matrix=rotation)
580
+ arr = norm_mask(arr, score_mask, out=arr)
581
+
582
+ callback(arr, rotation_matrix=rotation)
758
583
 
759
584
  return callback
760
585
 
@@ -772,6 +597,7 @@ def mcc_scoring(
772
597
  callback: CallbackClass,
773
598
  interpolation_order: int,
774
599
  overlap_ratio: float = 0.3,
600
+ score_mask: shm_type = None,
775
601
  ) -> CallbackClass:
776
602
  """
777
603
  Computes a normalized cross-correlation score between ``target`` (f),
@@ -821,6 +647,10 @@ def mcc_scoring(
821
647
  overlap_ratio : float, optional
822
648
  Required fractional mask overlap, 0.3 by default.
823
649
 
650
+ Returns
651
+ -------
652
+ CallbackClass
653
+
824
654
  References
825
655
  ----------
826
656
  .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
@@ -846,30 +676,12 @@ def mcc_scoring(
846
676
  temp3 = be.zeros(fast_shape, float_dtype)
847
677
  temp_ft = be.zeros(fast_ft_shape, complex_dtype)
848
678
 
849
- _fftargs = {
850
- "real_dtype": be._float_dtype,
851
- "cmpl_dtype": be._complex_dtype,
852
- "inv_output_shape": fast_shape,
853
- "fwd_axes": None,
854
- "inv_axes": None,
855
- "inv_shape": fast_ft_shape,
856
- "temp_fwd": temp,
857
- }
858
-
859
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
860
- rfftn, irfftn = be.build_fft(temp_inv=temp_ft, **_fftargs)
861
- _ = _fftargs.pop("temp_fwd", None)
862
-
863
679
  template_filter_func = _create_filter_func(
864
680
  arr_shape=template.shape,
865
- arr_ft_shape=fast_ft_shape,
866
- arr_filter=template_filter,
867
- rfftn=rfftn,
868
- irfftn=irfftn,
869
- **_fftargs,
681
+ filter_shape=template_filter.shape,
682
+ arr_padded=True,
870
683
  )
871
684
 
872
- callback_func = conditional_execute(callback, callback is not None)
873
685
  for index in range(rotations.shape[0]):
874
686
  rotation = rotations[index]
875
687
  template_rot = be.fill(template_rot, 0)
@@ -888,17 +700,19 @@ def mcc_scoring(
888
700
  template_filter_func(template_rot, temp_ft, template_filter)
889
701
  normalize_template(template_rot, temp, be.sum(temp))
890
702
 
891
- temp_ft = rfftn(template_rot, temp_ft)
892
- temp2 = irfftn(target_mask_ft * temp_ft, temp2)
893
- numerator = irfftn(target_ft * temp_ft, numerator)
703
+ temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
704
+ temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
705
+ numerator = be.irfftn(target_ft * temp_ft, out=numerator, s=fast_shape)
894
706
 
895
707
  # temp template_mask_rot | temp_ft template_mask_rot_ft
896
708
  # Calculate overlap of masks at every point in the convolution.
897
709
  # Locations with high overlap should not be taken into account.
898
- temp_ft = rfftn(temp, temp_ft)
899
- mask_overlap = irfftn(temp_ft * target_mask_ft, mask_overlap)
710
+ temp_ft = be.rfftn(temp, out=temp_ft, s=fast_shape)
711
+ mask_overlap = be.irfftn(
712
+ temp_ft * target_mask_ft, out=mask_overlap, s=fast_shape
713
+ )
900
714
  be.maximum(mask_overlap, eps, out=mask_overlap)
901
- temp = irfftn(temp_ft * target_ft, temp)
715
+ temp = be.irfftn(temp_ft * target_ft, out=temp, s=fast_shape)
902
716
 
903
717
  be.subtract(
904
718
  numerator,
@@ -908,21 +722,21 @@ def mcc_scoring(
908
722
 
909
723
  # temp_3 = fixed_denom
910
724
  be.multiply(temp_ft, target_ft2, out=temp_ft)
911
- temp3 = irfftn(temp_ft, temp3)
725
+ temp3 = be.irfftn(temp_ft, out=temp3, s=fast_shape)
912
726
  be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
913
727
  be.maximum(temp3, 0.0, out=temp3)
914
728
 
915
729
  # temp = moving_denom
916
- temp_ft = rfftn(be.square(template_rot), temp_ft)
730
+ temp_ft = be.rfftn(be.square(template_rot), out=temp_ft, s=fast_shape)
917
731
  be.multiply(target_mask_ft, temp_ft, out=temp_ft)
918
- temp = irfftn(temp_ft, temp)
732
+ temp = be.irfftn(temp_ft, out=temp, s=fast_shape)
919
733
 
920
734
  be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
921
735
  be.maximum(temp, 0.0, out=temp)
922
736
 
923
737
  # temp_2 = denom
924
738
  be.multiply(temp3, temp, out=temp)
925
- be.sqrt(temp, temp2)
739
+ be.sqrt(temp, out=temp2)
926
740
 
927
741
  # Pixels where `denom` is very small will introduce large
928
742
  # numbers after division. To get around this problem,
@@ -938,29 +752,11 @@ def mcc_scoring(
938
752
  mask_overlap, axis=axes, keepdims=True
939
753
  )
940
754
  temp[mask_overlap < number_px_threshold] = 0.0
941
- callback_func(temp, rotation_matrix=rotation)
755
+ callback(temp, rotation_matrix=rotation)
942
756
 
943
757
  return callback
944
758
 
945
759
 
946
- def _format_slice(shape, squeeze_axis):
947
- ret = tuple(
948
- slice(None) if i not in squeeze_axis else 0 for i, _ in enumerate(shape)
949
- )
950
- return ret
951
-
952
-
953
- def _get_batch_dim(target, template):
954
- target_batch, template_batch = [], []
955
- for i in range(len(target.shape)):
956
- if target.shape[i] == 1 and template.shape[i] != 1:
957
- template_batch.append(i)
958
- if target.shape[i] != 1 and template.shape[i] == 1:
959
- target_batch.append(i)
960
-
961
- return target_batch, template_batch
962
-
963
-
964
760
  def flc_scoring2(
965
761
  template: shm_type,
966
762
  template_mask: shm_type,
@@ -972,26 +768,25 @@ def flc_scoring2(
972
768
  rotations: BackendArray,
973
769
  callback: CallbackClass,
974
770
  interpolation_order: int,
771
+ score_mask: shm_type = None,
975
772
  ) -> CallbackClass:
976
- callback_func = conditional_execute(callback, callback is not None)
977
-
978
- # Retrieve objects from shared memory
979
773
  template = be.from_sharedarr(template)
980
774
  template_mask = be.from_sharedarr(template_mask)
981
775
  ft_target = be.from_sharedarr(ft_target)
982
776
  ft_target2 = be.from_sharedarr(ft_target2)
983
777
  template_filter = be.from_sharedarr(template_filter)
778
+ score_mask = be.from_sharedarr(score_mask)
984
779
 
985
- data_axes = None
986
- target_batch, template_batch = _get_batch_dim(ft_target, template)
987
- sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
988
- sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
780
+ tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
989
781
 
990
- data_shape = fast_shape
991
- if len(target_batch) or len(template_batch):
992
- batch = (*target_batch, *template_batch)
993
- data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
994
- data_shape = tuple(fast_shape[i] for i in data_axes)
782
+ nd = len(fast_shape)
783
+ sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
784
+ tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
785
+
786
+ axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
787
+ if len(tar_batch) or len(tmpl_batch):
788
+ axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
789
+ shape = tuple(fast_shape[i] for i in axes)
995
790
 
996
791
  arr = be.zeros(fast_shape, be._float_dtype)
997
792
  temp = be.zeros(fast_shape, be._float_dtype)
@@ -999,68 +794,47 @@ def flc_scoring2(
999
794
  ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)
1000
795
 
1001
796
  tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
1002
- if be.size(template_filter) != 1:
1003
- ret_shape = np.broadcast_shapes(
1004
- sqz_cmpl, tuple(int(x) for x in template_filter.shape)
1005
- )
1006
- ft_temp = be.zeros(ret_shape, be._complex_dtype)
1007
-
1008
- _fftargs = {
1009
- "real_dtype": be._float_dtype,
1010
- "cmpl_dtype": be._complex_dtype,
1011
- "inv_output_shape": fast_shape,
1012
- "fwd_axes": data_axes,
1013
- "inv_axes": data_axes,
1014
- "inv_shape": fast_ft_shape,
1015
- "temp_fwd": arr_sqz if _shape_match(ft_temp.shape, sqz_cmpl) else arr,
1016
- }
1017
-
1018
- # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
1019
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
1020
- rfftn, irfftn = be.build_fft(temp_inv=ft_denom, **_fftargs)
1021
- _ = _fftargs.pop("temp_fwd", None)
1022
797
 
1023
798
  template_filter_func = _create_filter_func(
1024
799
  arr_shape=template.shape,
1025
- arr_ft_shape=sqz_cmpl,
1026
- arr_filter=template_filter,
1027
- rfftn=rfftn,
1028
- irfftn=irfftn,
1029
- **_fftargs,
800
+ filter_shape=template_filter.shape,
801
+ arr_padded=True,
1030
802
  )
803
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
1031
804
 
1032
805
  eps = be.eps(be._float_dtype)
1033
806
  for index in range(rotations.shape[0]):
1034
807
  rotation = rotations[index]
1035
808
  be.fill(arr, 0)
1036
809
  be.fill(temp, 0)
1037
- arr_sqz, tmp_sqz = be.rigid_transform(
1038
- arr=template,
1039
- arr_mask=template_mask,
810
+
811
+ _, _ = be.rigid_transform(
812
+ arr=template[tmpl_subset],
813
+ arr_mask=template_mask[tmpl_subset],
1040
814
  rotation_matrix=rotation,
1041
- out=arr_sqz,
1042
- out_mask=tmp_sqz,
815
+ out=arr_sqz[tmpl_subset],
816
+ out_mask=tmp_sqz[tmpl_subset],
1043
817
  use_geometric_center=True,
1044
818
  order=interpolation_order,
1045
- cache=True,
1046
- batched=True,
819
+ cache=False,
820
+ batched=batched,
1047
821
  )
1048
- n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
822
+
823
+ n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
1049
824
  arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
1050
- arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=data_axes)
825
+ arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=axes)
826
+
827
+ ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
828
+ temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
829
+ temp2 = _correlate_fts(ft_target2, ft_temp, ft_denom, temp2, shape, axes)
1051
830
 
1052
- ft_temp = be.rfftn(tmp_sqz, ft_temp, axes=data_axes)
1053
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
1054
- temp = be.irfftn(ft_denom, temp, axes=data_axes, s=data_shape)
1055
- ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
1056
- temp2 = be.irfftn(ft_denom, temp2, axes=data_axes, s=data_shape)
831
+ ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
832
+ arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
1057
833
 
1058
- ft_temp = rfftn(arr_norm, ft_denom)
1059
- ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
1060
- arr = irfftn(ft_denom, arr)
834
+ arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
835
+ arr = norm_mask(arr, score_mask, out=arr)
1061
836
 
1062
- be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
1063
- callback_func(arr, rotation_matrix=rotation)
837
+ callback(arr, rotation_matrix=rotation)
1064
838
 
1065
839
  return callback
1066
840
 
@@ -1078,108 +852,152 @@ def corr_scoring2(
1078
852
  interpolation_order: int,
1079
853
  target_filter: shm_type = None,
1080
854
  template_mask: shm_type = None,
855
+ score_mask: shm_type = None,
1081
856
  ) -> CallbackClass:
1082
857
  template = be.from_sharedarr(template)
1083
858
  ft_target = be.from_sharedarr(ft_target)
1084
859
  inv_denominator = be.from_sharedarr(inv_denominator)
1085
860
  numerator = be.from_sharedarr(numerator)
1086
861
  template_filter = be.from_sharedarr(template_filter)
862
+ score_mask = be.from_sharedarr(score_mask)
1087
863
 
1088
- data_axes = None
1089
- target_batch, template_batch = _get_batch_dim(ft_target, template)
1090
- sqz_cmpl = tuple(1 if i in target_batch else x for i, x in enumerate(fast_ft_shape))
1091
- sqz_slice = tuple(slice(0, 1) if x == 1 else slice(None) for x in sqz_cmpl)
1092
- unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
1093
- if len(target_batch) or len(template_batch):
1094
- batch = (*target_batch, *template_batch)
1095
- data_axes = tuple(i for i in range(len(fast_shape)) if i not in batch)
1096
- unpadded_slice = tuple(
1097
- slice(None) if i in batch else slice(0, x)
1098
- for i, x in enumerate(template.shape)
1099
- )
864
+ tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
865
+
866
+ nd = len(fast_shape)
867
+ sqz_slice = tuple(slice(0, 1) if i in tar_batch else slice(None) for i in range(nd))
868
+ tmpl_subset = tuple(0 if i in tar_batch else slice(None) for i in range(nd))
869
+
870
+ axes, shape, batched = None, fast_shape, len(tmpl_batch) > 0
871
+ if len(tar_batch) or len(tmpl_batch):
872
+ axes = tuple(i for i in range(nd) if i not in (*tar_batch, *tmpl_batch))
873
+ shape = tuple(fast_shape[i] for i in axes)
874
+
875
+ unpadded_slice = tuple(
876
+ slice(None) if i in (*tar_batch, *tmpl_batch) else slice(0, x)
877
+ for i, x in enumerate(template.shape)
878
+ )
1100
879
 
1101
880
  arr = be.zeros(fast_shape, be._float_dtype)
1102
881
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
1103
882
  arr_sqz, ft_sqz = arr[sqz_slice], ft_temp[sqz_slice]
1104
883
 
1105
- if be.size(template_filter) != 1:
1106
- # The filter could be w.r.t the unpadded template
1107
- ret_shape = tuple(
1108
- int(x * y) if x == 1 or y == 1 else y
1109
- for x, y in zip(sqz_cmpl, template_filter.shape)
1110
- )
1111
- ft_sqz = be.zeros(ret_shape, be._complex_dtype)
1112
-
1113
- norm_func, norm_template, mask_sum = normalize_template, False, 1
884
+ n_obs = None
1114
885
  if template_mask is not None:
1115
886
  template_mask = be.from_sharedarr(template_mask)
1116
- norm_template, mask_sum = True, be.sum(
1117
- be.astype(template_mask, be._overflow_safe_dtype),
1118
- axis=data_axes,
1119
- keepdims=True,
1120
- )
1121
- if be.datatype_bytes(template_mask.dtype) == 2:
1122
- norm_func = _normalize_template_overflow_safe
1123
-
1124
- callback_func = conditional_execute(callback, callback is not None)
1125
- norm_template = conditional_execute(norm_func, norm_template)
1126
- norm_numerator = conditional_execute(
1127
- be.subtract, identity, _shape_match(numerator.shape, fast_shape)
1128
- )
1129
- norm_denominator = conditional_execute(
1130
- be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
1131
- )
887
+ n_obs = be.sum(template_mask, axis=axes, keepdims=True)
1132
888
 
1133
- _fftargs = {
1134
- "real_dtype": be._float_dtype,
1135
- "cmpl_dtype": be._complex_dtype,
1136
- "fwd_axes": data_axes,
1137
- "inv_axes": data_axes,
1138
- "inv_shape": fast_ft_shape,
1139
- "inv_output_shape": fast_shape,
1140
- "temp_fwd": arr_sqz if _shape_match(ft_sqz.shape, sqz_cmpl) else arr,
1141
- }
1142
-
1143
- # build_fft ignores fwd_shape if temp_fwd is given and serves only for bookkeeping
1144
- _fftargs["fwd_shape"] = _fftargs["temp_fwd"].shape
1145
- rfftn, irfftn = be.build_fft(temp_inv=ft_temp, **_fftargs)
1146
- _ = _fftargs.pop("temp_fwd", None)
889
+ norm_template = conditional_execute(normalize_template, n_obs is not None)
890
+ norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
891
+ norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
892
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
1147
893
 
1148
894
  template_filter_func = _create_filter_func(
1149
895
  arr_shape=template.shape,
1150
- arr_ft_shape=sqz_cmpl,
1151
- arr_filter=template_filter,
1152
- rfftn=rfftn,
1153
- irfftn=irfftn,
1154
- **_fftargs,
896
+ filter_shape=template_filter.shape,
897
+ arr_padded=True,
1155
898
  )
1156
899
 
1157
900
  for index in range(rotations.shape[0]):
1158
901
  be.fill(arr, 0)
1159
902
  rotation = rotations[index]
1160
- arr_sqz, _ = be.rigid_transform(
1161
- arr=template,
903
+ _, _ = be.rigid_transform(
904
+ arr=template[tmpl_subset],
1162
905
  rotation_matrix=rotation,
1163
- out=arr_sqz,
906
+ out=arr_sqz[tmpl_subset],
1164
907
  use_geometric_center=True,
1165
908
  order=interpolation_order,
1166
- cache=True,
1167
- batched=True,
909
+ cache=False,
910
+ batched=batched,
1168
911
  )
1169
912
  arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
1170
- norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)
913
+ norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)
1171
914
 
1172
- ft_sqz = rfftn(arr_norm, ft_sqz)
1173
- ft_temp = be.multiply(ft_target, ft_sqz, out=ft_temp)
1174
- arr = irfftn(ft_temp, arr)
915
+ ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
916
+ arr = _correlate_fts(ft_target, ft_sqz, ft_temp, arr, shape, axes)
1175
917
 
1176
- arr = norm_numerator(arr, numerator, out=arr)
1177
- arr = norm_denominator(arr, inv_denominator, out=arr)
1178
- callback_func(arr, rotation_matrix=rotation)
918
+ arr = norm_sub(arr, numerator, out=arr)
919
+ arr = norm_mul(arr, inv_denominator, out=arr)
920
+ arr = norm_mask(arr, score_mask, out=arr)
921
+
922
+ callback(arr, rotation_matrix=rotation)
1179
923
 
1180
924
  return callback
1181
925
 
1182
926
 
927
+ def _get_batch_dim(target, template):
928
+ target_batch, template_batch = [], []
929
+ for i in range(len(target.shape)):
930
+ if target.shape[i] == 1 and template.shape[i] != 1:
931
+ template_batch.append(i)
932
+ if target.shape[i] != 1 and template.shape[i] == 1:
933
+ target_batch.append(i)
934
+
935
+ return target_batch, template_batch
936
+
937
+
938
+ def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=None):
939
+ ft_buffer = be.multiply(ft_tar, ft_tmpl, out=ft_buffer)
940
+ return be.irfftn(ft_buffer, out=real_buffer, s=fast_shape, axes=axes)
941
+
942
+
943
+ def _create_filter_func(
944
+ arr_shape: Tuple[int],
945
+ filter_shape: BackendArray,
946
+ arr_padded: bool = False,
947
+ axes=None,
948
+ ) -> Callable:
949
+ """
950
+ Configure template filtering function for Fourier transforms.
951
+
952
+ Conceptually we distinguish between three cases. The base case
953
+ is that both template and the corresponding filter have the same
954
+ shape. Padding is used when the template filter is larger than
955
+ the template, for instance to better resolve Fourier filters. Finally
956
+ this function also handles the case when a filter is supposed to be
957
+ broadcasted over the template batch dimension.
958
+
959
+ Parameters
960
+ ----------
961
+ arr_shape : tuple of ints
962
+ Shape of the array to be filtered.
963
+ filter_shape : BackendArray
964
+ Precomputed filter to apply in the frequency domain.
965
+ arr_padded : bool, optional
966
+ Whether the input template is padded and will need to be cropped
967
+ to arr_shape prior to filter applications. Defaults to False.
968
+ axes : tuple of ints, optional
969
+ Axes to perform Fourier transform over.
970
+
971
+ Returns
972
+ -------
973
+ Callable
974
+ Filter function with parameters template, ft_temp and template_filter.
975
+ """
976
+ if filter_shape == (1,):
977
+ return conditional_execute(identity, execute_operation=True)
978
+
979
+ # Default case, all shapes are correctly matched
980
+ def _apply_filter(template, ft_temp, template_filter):
981
+ ft_temp = be.rfftn(template, out=ft_temp, s=template.shape)
982
+ ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
983
+ return be.irfftn(ft_temp, out=template, s=template.shape)
984
+
985
+ if not arr_padded:
986
+ return _apply_filter
987
+
988
+ # Array is padded but filter is w.r.t to the original template
989
+ real_subset = tuple(slice(0, x) for x in arr_shape)
990
+ _template = be.zeros(arr_shape, be._float_dtype)
991
+ _ft_temp = be.zeros(filter_shape, be._complex_dtype)
992
+
993
+ def _apply_filter_subset(template, ft_temp, template_filter):
994
+ _template[:] = template[real_subset]
995
+ template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
996
+ return template
997
+
998
+ return _apply_filter_subset
999
+
1000
+
1183
1001
  MATCHING_EXHAUSTIVE_REGISTER = {
1184
1002
  "CC": (cc_setup, corr_scoring),
1185
1003
  "LCC": (lcc_setup, corr_scoring),
@@ -1188,6 +1006,6 @@ MATCHING_EXHAUSTIVE_REGISTER = {
1188
1006
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1189
1007
  "FLC": (flc_setup, flc_scoring),
1190
1008
  "MCC": (mcc_setup, mcc_scoring),
1191
- "batchFLCSpherical": (flcSphericalMask_setup, corr_scoring2),
1009
+ "batchFLCSphericalMask": (flcSphericalMask_setup, corr_scoring2),
1192
1010
  "batchFLC": (flc_setup, flc_scoring2),
1193
1011
  }